[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 303d44790a
commit 0f6d7a9a34
.buildkite/test-pipeline.yaml
@@ -146,7 +146,9 @@ steps:
  source_file_dependencies:
  - vllm/
  - tests/test_regression
  commands:
  - pip install modelscope
  - pytest -v -s test_regression.py
  working_dir: "/vllm-workspace/tests" # optional

- label: Engine Test # 10min
docs/source/models/supported_models.rst
@@ -12,201 +12,249 @@ Alongside each architecture, we include some popular models that use it.

Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1

  * - Architecture
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`AquilaForCausalLM`
    - Aquila, Aquila2
    - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
    - ✅︎
    - ✅︎
  * - :code:`ArcticForCausalLM`
    - Arctic
    - :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc.
    -
    - ✅︎
  * - :code:`BaiChuanForCausalLM`
    - Baichuan2, Baichuan
    - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc.
    - ✅︎
    - ✅︎
  * - :code:`BloomForCausalLM`
    - BLOOM, BLOOMZ, BLOOMChat
    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
    -
    - ✅︎
  * - :code:`ChatGLMModel`
    - ChatGLM
    - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
    - ✅︎
    - ✅︎
  * - :code:`CohereForCausalLM`
    - Command-R
    - :code:`CohereForAI/c4ai-command-r-v01`, etc.
    - ✅︎
    - ✅︎
  * - :code:`DbrxForCausalLM`
    - DBRX
    - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc.
    -
    - ✅︎
  * - :code:`DeciLMForCausalLM`
    - DeciLM
    - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
    -
    - ✅︎
  * - :code:`DeepseekForCausalLM`
    - DeepSeek
    - :code:`deepseek-ai/deepseek-llm-67b-base`, :code:`deepseek-ai/deepseek-llm-7b-chat` etc.
    -
    - ✅︎
  * - :code:`DeepseekV2ForCausalLM`
    - DeepSeek-V2
    - :code:`deepseek-ai/DeepSeek-V2`, :code:`deepseek-ai/DeepSeek-V2-Chat` etc.
    -
    - ✅︎
  * - :code:`ExaoneForCausalLM`
    - EXAONE-3
    - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
    - ✅︎
    - ✅︎
  * - :code:`FalconForCausalLM`
    - Falcon
    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
    -
    - ✅︎
  * - :code:`GemmaForCausalLM`
    - Gemma
    - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
    - ✅︎
    - ✅︎
  * - :code:`Gemma2ForCausalLM`
    - Gemma2
    - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc.
    - ✅︎
    - ✅︎
  * - :code:`GPT2LMHeadModel`
    - GPT-2
    - :code:`gpt2`, :code:`gpt2-xl`, etc.
    -
    - ✅︎
  * - :code:`GPTBigCodeForCausalLM`
    - StarCoder, SantaCoder, WizardCoder
    - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
    - ✅︎
    - ✅︎
  * - :code:`GPTJForCausalLM`
    - GPT-J
    - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
    -
    - ✅︎
  * - :code:`GPTNeoXForCausalLM`
    - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
    - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
    -
    - ✅︎
  * - :code:`GraniteForCausalLM`
    - PowerLM
    - :code:`ibm/PowerLM-3b` etc.
    - ✅︎
    - ✅︎
  * - :code:`GraniteMoeForCausalLM`
    - PowerMoE
    - :code:`ibm/PowerMoE-3b` etc.
    - ✅︎
    - ✅︎
  * - :code:`InternLMForCausalLM`
    - InternLM
    - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
    - ✅︎
    - ✅︎
  * - :code:`InternLM2ForCausalLM`
    - InternLM2
    - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
    -
    - ✅︎
  * - :code:`JAISLMHeadModel`
    - Jais
    - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
    -
    - ✅︎
  * - :code:`JambaForCausalLM`
    - Jamba
    - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
    - ✅︎
    -
  * - :code:`LlamaForCausalLM`
    - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
    - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
    - ✅︎
    - ✅︎
  * - :code:`MiniCPMForCausalLM`
    - MiniCPM
    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
    - ✅︎
    - ✅︎
  * - :code:`MiniCPM3ForCausalLM`
    - MiniCPM3
    - :code:`openbmb/MiniCPM3-4B`, etc.
    - ✅︎
    - ✅︎
  * - :code:`MistralForCausalLM`
    - Mistral, Mistral-Instruct
    - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
    - ✅︎
    - ✅︎
  * - :code:`MixtralForCausalLM`
    - Mixtral-8x7B, Mixtral-8x7B-Instruct
    - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
    - ✅︎
    - ✅︎
  * - :code:`MPTForCausalLM`
    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
    - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
    -
    - ✅︎
  * - :code:`NemotronForCausalLM`
    - Nemotron-3, Nemotron-4, Minitron
    - :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc.
    - ✅︎
    - ✅︎
  * - :code:`OLMoForCausalLM`
    - OLMo
    - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
    -
    - ✅︎
  * - :code:`OLMoEForCausalLM`
    - OLMoE
    - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
    - ✅︎
    - ✅︎
  * - :code:`OPTForCausalLM`
    - OPT, OPT-IML
    - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
    -
    - ✅︎
  * - :code:`OrionForCausalLM`
    - Orion
    - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc.
    -
    - ✅︎
  * - :code:`PhiForCausalLM`
    - Phi
    - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
    - ✅︎
    - ✅︎
  * - :code:`Phi3ForCausalLM`
    - Phi-3
    - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, :code:`microsoft/Phi-3-medium-128k-instruct`, etc.
    - ✅︎
    - ✅︎
  * - :code:`Phi3SmallForCausalLM`
    - Phi-3-Small
    - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
    -
    - ✅︎
  * - :code:`PhiMoEForCausalLM`
    - Phi-3.5-MoE
    - :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
    - ✅︎
    - ✅︎
  * - :code:`PersimmonForCausalLM`
    - Persimmon
    - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
    -
    - ✅︎
  * - :code:`QWenLMHeadModel`
    - Qwen
    - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
    -
    - ✅︎
  * - :code:`Qwen2ForCausalLM`
    - Qwen2
    - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
    - ✅︎
    - ✅︎
  * - :code:`Qwen2MoeForCausalLM`
    - Qwen2MoE
    - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.
    -
    - ✅︎
  * - :code:`StableLmForCausalLM`
    - StableLM
    - :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
    -
    - ✅︎
  * - :code:`Starcoder2ForCausalLM`
    - Starcoder2
    - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc.
    -
    - ✅︎
  * - :code:`SolarForCausalLM`
    - Solar Pro
    - :code:`upstage/solar-pro-preview-instruct`, etc.
    - ✅︎
    - ✅︎
  * - :code:`XverseForCausalLM`
    - XVERSE
    - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
    - ✅︎
    - ✅︎

.. note::
  Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
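The :code:`PP` column can also be queried programmatically through the model
registry API added in this commit. A minimal sketch, using two architectures
from the table above (Jamba is a row whose PP cell is blank):

.. code-block:: python

   from vllm.model_executor.models import ModelRegistry

   assert ModelRegistry.is_pp_supported_model("LlamaForCausalLM")
   assert not ModelRegistry.is_pp_supported_model("JambaForCausalLM")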
@@ -217,7 +265,7 @@ Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. list-table::
  :widths: 25 25 25 25 5 5
  :header-rows: 1

  * - Architecture
@@ -225,86 +273,103 @@ Multimodal Language Models
    - Modalities
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`Blip2ForConditionalGeneration`
    - BLIP-2
    - Image\ :sup:`E`
    - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
    -
    - ✅︎
  * - :code:`ChameleonForConditionalGeneration`
    - Chameleon
    - Image
    - :code:`facebook/chameleon-7b` etc.
    -
    - ✅︎
  * - :code:`FuyuForCausalLM`
    - Fuyu
    - Image
    - :code:`adept/fuyu-8b` etc.
    -
    - ✅︎
  * - :code:`InternVLChatModel`
    - InternVL2
    - Image\ :sup:`E+`
    - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
    -
    - ✅︎
  * - :code:`LlavaForConditionalGeneration`
    - LLaVA-1.5
    - Image\ :sup:`E+`
    - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
    -
    - ✅︎
  * - :code:`LlavaNextForConditionalGeneration`
    - LLaVA-NeXT
    - Image\ :sup:`E+`
    - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
    -
    - ✅︎
  * - :code:`LlavaNextVideoForConditionalGeneration`
    - LLaVA-NeXT-Video
    - Video
    - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc.
    -
    - ✅︎
  * - :code:`LlavaOnevisionForConditionalGeneration`
    - LLaVA-Onevision
    - Image\ :sup:`+` / Video
    - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
    -
    - ✅︎
  * - :code:`MiniCPMV`
    - MiniCPM-V
    - Image\ :sup:`+`
    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
    - ✅︎
    - ✅︎
  * - :code:`MllamaForConditionalGeneration`
    - Llama 3.2
    - Image
    - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
    -
    -
  * - :code:`PaliGemmaForConditionalGeneration`
    - PaliGemma
    - Image\ :sup:`E`
    - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
    -
    - ✅︎
  * - :code:`Phi3VForCausalLM`
    - Phi-3-Vision, Phi-3.5-Vision
    - Image\ :sup:`E+`
    - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
    -
    - ✅︎
  * - :code:`PixtralForConditionalGeneration`
    - Pixtral
    - Image\ :sup:`+`
    - :code:`mistralai/Pixtral-12B-2409`
    -
    - ✅︎
  * - :code:`QWenLMHeadModel`
    - Qwen-VL
    - Image\ :sup:`E+`
    - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
    -
    - ✅︎
  * - :code:`Qwen2VLForConditionalGeneration`
    - Qwen2-VL
    - Image\ :sup:`E+` / Video\ :sup:`+`
    - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
    -
    - ✅︎
  * - :code:`UltravoxModel`
    - Ultravox
    - Audio\ :sup:`E+`
    - :code:`fixie-ai/ultravox-v0_3`
    -
    - ✅︎

| :sup:`E` Pre-computed embeddings can be inputted for this modality.
| :sup:`+` Multiple items can be inputted per text prompt for this modality.
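The same registry check covers the multimodal table. A minimal sketch
(architectures taken from the table above; :code:`MllamaForConditionalGeneration`
is the row whose PP cell is blank):

.. code-block:: python

   from vllm.model_executor.models import ModelRegistry

   assert ModelRegistry.is_pp_supported_model("Qwen2VLForConditionalGeneration")
   assert not ModelRegistry.is_pp_supported_model("MllamaForConditionalGeneration")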
requirements-test.txt
@@ -10,8 +10,8 @@ pytest-shard
awscli
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
opencv-python # required for video tests
peft
requests
ray[adag]==2.35
tests/distributed/test_pipeline_parallel.py
@@ -6,6 +6,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
to fail.
"""
import os
from dataclasses import dataclass
from typing import List, NamedTuple, Optional

import pytest

@@ -18,79 +20,236 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    tp_size: int
    pp_size: int
    eager_mode: bool
    chunked_prefill: bool


@dataclass
class PPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    trust_remote_code: bool
    tokenizer_mode: Optional[str]

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
    ):
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            trust_remote_code=trust_remote_code,
            tokenizer_mode=tokenizer_mode,
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
    ):
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            trust_remote_code=trust_remote_code,
            tokenizer_mode=tokenizer_mode,
        )

    def iter_params(self, model_name: str):
        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.trust_remote_code, self.tokenizer_mode)

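# Usage sketch (derived from the settings above, not new API): detailed()
# yields 5 parallel setups x 2 distributed backends = 10 parametrizations per
# model, while fast() yields exactly one.
settings = PPTestSettings.detailed(tp_base=1, pp_base=2)
params = list(settings.iter_params("meta-llama/Meta-Llama-3-8B"))
assert len(params) == 10  # 5 parallel setups x 2 distributed backends
model_name, parallel_setup, backend, trust_remote_code, tok_mode = params[0]
assert parallel_setup == ParallelSetup(
    tp_size=1, pp_size=2, eager_mode=False, chunked_prefill=False)
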
# yapf: disable
GENERATION_MODEL_SETTINGS = {
    # [DETAILED TESTS]
    "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
    # [FAST TESTS]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    # TODO: Test on larger GPU
    # "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True),  # noqa: E501
    # TODO: Test on larger GPU
    # "databricks/dbrx-instruct": PPTestSettings.fast(),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-2b": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-12b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
    "core42/jais-13b-chat": PPTestSettings.fast(),
    # TODO: Implement PP
    # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    # FIXME: https://github.com/vllm-project/vllm/issues/8553
    # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
    # FIXME: Cannot load tokenizer in latest transformers version
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
}

EMBEDDING_MODEL_SETTINGS = {  # type: ignore[var-annotated]
    # [FAST TESTS]
    # Uses Llama
    # "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
}

MULTIMODAL_MODEL_SETTINGS = {
    # [FAST TESTS]
    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
    "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"),  # noqa: E501
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
}

CONDITIONAL_GENERATION_MODEL_SETTINGS = {  # type: ignore[var-annotated]
    # [FAST TESTS]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}
# yapf: enable

MODEL_SETTINGS = {
    **GENERATION_MODEL_SETTINGS,
    **EMBEDDING_MODEL_SETTINGS,
    **MULTIMODAL_MODEL_SETTINGS,
}

# You can update this on your local machine to run specific tests
TEST_MODELS = [
    "meta-llama/Meta-Llama-3-8B",
    "facebook/chameleon-7b",
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3-vision-128k-instruct",
    "mistralai/Pixtral-12B-2409",
    "fixie-ai/ultravox-v0_3",
]

@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend",
     "trust_remote_code", "tokenizer_mode"),
    [
        params for model_name, settings in MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
        if model_name in TEST_MODELS
    ],
)
@fork_new_process_for_each_test
def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
                    distributed_backend: str, trust_remote_code: bool,
                    tokenizer_mode: Optional[str], num_gpus_available):
    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs to run the test")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])

    if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
            and chunked_prefill):
        # Test Ray ADAG for a subset of the tests
        pp_env = {
            "VLLM_USE_RAY_COMPILED_DAG": "1",
@@ -99,11 +258,35 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of aDAG issue.
        common_args.append("--disable-frontend-multiprocessing")
    else:
        pp_env = None

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    # schedule all workers in a node other than the head node,
    # which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name, pp_args, tp_args, pp_env)
    except Exception:
        if pp_env is None:
            raise

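# Worked example of the argument assembly above (derived from this file, not
# new behavior): ParallelSetup(tp_size=2, pp_size=2, eager_mode=False,
# chunked_prefill=True) with the "ray" backend hits the Ray ADAG branch, so
# common_args gains "--enable-chunked-prefill" and
# "--disable-frontend-multiprocessing", pp_env sets
# VLLM_USE_RAY_COMPILED_DAG=1, and the two servers are launched with:
#   pp_args = [*common_args, "--pipeline-parallel-size", "2",
#              "--tensor-parallel-size", "2",
#              "--distributed-executor-backend", "ray"]
#   tp_args = [*common_args, "--tensor-parallel-size", "2",
#              "--distributed-executor-backend", "mp"]
# compare_two_settings (tests/utils.py, below) then launches both
# configurations and compares their outputs.
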
tests/models/test_registry.py
@@ -1,9 +1,55 @@
import warnings

import pytest
import torch.cuda

from vllm.model_executor.models import _MODELS, ModelRegistry
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model_arch", _MODELS)
def test_registry_imports(model_arch):
    # Ensure all model classes can be imported successfully
    ModelRegistry.resolve_model_cls(model_arch)


@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
    ("LlamaForCausalLM", False, False),
    ("MllamaForConditionalGeneration", True, False),
    ("LlavaForConditionalGeneration", True, True),
])
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
    assert ModelRegistry.is_multimodal_model(model_arch) is is_mm

    if init_cuda and current_platform.is_cuda_alike():
        assert not torch.cuda.is_initialized()

        ModelRegistry.resolve_model_cls(model_arch)
        if not torch.cuda.is_initialized():
            warnings.warn(
                "This model no longer initializes CUDA on import. "
                "Please test using a different one.",
                stacklevel=2)


@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
    ("MLPSpeculatorPreTrainedModel", False, False),
    ("DeepseekV2ForCausalLM", True, False),
    ("Qwen2VLForConditionalGeneration", True, True),
])
def test_registry_is_pp(model_arch, is_pp, init_cuda):
    assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp

    if init_cuda and current_platform.is_cuda_alike():
        assert not torch.cuda.is_initialized()

        ModelRegistry.resolve_model_cls(model_arch)
        if not torch.cuda.is_initialized():
            warnings.warn(
                "This model no longer initializes CUDA on import. "
                "Please test using a different one.",
                stacklevel=2)

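# The supports_pp predicate exercised above lives in
# vllm/model_executor/models/interfaces.py and is not part of this diff.
# Conceptually it checks for the SupportsPP marker that model classes mix in
# (see ArcticForCausalLM further below); a rough sketch, with the attribute
# name assumed for illustration:
def _supports_pp_sketch(model_cls: type) -> bool:
    return bool(getattr(model_cls, "supports_pp", False))
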
tests/utils.py
@@ -14,7 +14,6 @@ import openai
import pytest
import requests
from openai.types.completion import Completion
from typing_extensions import ParamSpec

from tests.models.utils import TextTextLogprobs
@@ -24,6 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
                        cuda_device_count_stateless, get_open_port, is_hip)

@@ -181,15 +181,26 @@ def compare_two_settings(model: str,
        env2: The second set of environment variables to pass to the API server.
    """
    trust_remote_code = False
    for args in (arg1, arg2):
        if "--trust-remote-code" in args:
            trust_remote_code = True
            break

    tokenizer_mode = "auto"
    for args in (arg1, arg2):
        if "--tokenizer-mode" in args:
            tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
            break

    tokenizer = get_tokenizer(
        model,
        trust_remote_code=trust_remote_code,
        tokenizer_mode=tokenizer_mode,
    )

    prompt = "Hello, my name is"
    token_ids = tokenizer(prompt).input_ids
    results = []
    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
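# Self-contained demo of the flag scanning above (pure Python, runnable
# without vLLM; mirrors the loops in compare_two_settings):
def _scan(arg1, arg2):
    trust_remote_code = any("--trust-remote-code" in args
                            for args in (arg1, arg2))
    tokenizer_mode = "auto"
    for args in (arg1, arg2):
        if "--tokenizer-mode" in args:
            tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
            break
    return trust_remote_code, tokenizer_mode

assert _scan(["--tokenizer-mode", "mistral"],
             ["--trust-remote-code"]) == (True, "mistral")
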
vllm/config.py
@@ -31,28 +31,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)

_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120


class ModelConfig:
@@ -228,16 +207,14 @@ class ModelConfig:
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        architectures = getattr(self.hf_config, "architectures", [])
        if ModelRegistry.is_multimodal_model(architectures):
            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})

        if limit_mm_per_prompt:
            raise ValueError("`limit_mm_per_prompt` is only supported for "
                             "multimodal models.")

        return None

    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
@@ -249,8 +226,7 @@ class ModelConfig:

    def _verify_embedding_mode(self) -> None:
        architectures = getattr(self.hf_config, "architectures", [])
        self.embedding_mode = ModelRegistry.is_embedding_model(architectures)

    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
@@ -417,17 +393,17 @@ class ModelConfig:
                f"({tensor_parallel_size}).")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if pipeline_parallel_size > 1:
            architectures = getattr(self.hf_config, "architectures", [])
            if not ModelRegistry.is_pp_supported_model(architectures):
                raise NotImplementedError(
                    "Pipeline parallelism is not supported for this model. "
                    "Supported models implement the `SupportsPP` interface.")

            if self.use_async_output_proc:
                logger.warning("Async output processor is not supported with "
                               "pipeline parallelism currently. Disabling it.")
                self.use_async_output_proc = False

    def get_hf_config_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled."""

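# Hypothetical repro of the new PP gate above (offline LLM entrypoint assumed
# for brevity; Jamba lacks SupportsPP at this commit per the docs table):
from vllm import LLM

try:
    LLM("ai21labs/Jamba-v0.1", pipeline_parallel_size=2)
except NotImplementedError as exc:
    print(exc)  # Pipeline parallelism is not supported for this model. ...
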
vllm/model_executor/models/__init__.py
@@ -1,12 +1,18 @@
import importlib
import string
import subprocess
import sys
import uuid
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

from .interfaces import supports_multimodal, supports_pp

logger = init_logger(__name__)

_GENERATION_MODELS = {
@@ -152,19 +158,25 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class ModelRegistry:

    @staticmethod
    def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
        module_relname, cls_name = _MODELS[model_arch]
        return f"vllm.model_executor.models.{module_relname}", cls_name

    @staticmethod
    @lru_cache(maxsize=128)
    def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch not in _MODELS:
            return None

        module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
        module = importlib.import_module(module_name)
        return getattr(module, cls_name, None)

    @staticmethod
    def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch in _OOT_MODELS:
            return _OOT_MODELS[model_arch]

        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
@@ -175,11 +187,24 @@ class ModelRegistry:
                    "Model architecture %s is partially supported by ROCm: %s",
                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])

        return None

    @staticmethod
    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        model = ModelRegistry._try_get_model_stateless(model_arch)
        if model is not None:
            return model

        return ModelRegistry._try_get_model_stateful(model_arch)

    @staticmethod
    def resolve_model_cls(
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        for arch in architectures:
            model_cls = ModelRegistry._try_load_model_cls(arch)
            if model_cls is not None:
@@ -200,21 +225,99 @@ class ModelRegistry:
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls.__name__)

        _OOT_MODELS[model_arch] = model_cls

    @staticmethod
    @lru_cache(maxsize=128)
    def _check_stateless(
        func: Callable[[Type[nn.Module]], bool],
        model_arch: str,
        *,
        default: Optional[bool] = None,
    ) -> bool:
        """
        Run a boolean function against a model and return the result.

        If the model is not found, returns the provided default value.

        If the model is not already imported, the function is run inside a
        subprocess to avoid initializing CUDA for the main program.
        """
        model = ModelRegistry._try_get_model_stateless(model_arch)
        if model is not None:
            return func(model)

        if model_arch not in _MODELS and default is not None:
            return default

        module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)

        valid_name_characters = string.ascii_letters + string.digits + "._"
        if any(s not in valid_name_characters for s in module_name):
            raise ValueError(f"Unsafe module name detected for {model_arch}")
        if any(s not in valid_name_characters for s in cls_name):
            raise ValueError(f"Unsafe class name detected for {model_arch}")
        if any(s not in valid_name_characters for s in func.__module__):
            raise ValueError(f"Unsafe module name detected for {func}")
        if any(s not in valid_name_characters for s in func.__name__):
            raise ValueError(f"Unsafe class name detected for {func}")

        err_id = uuid.uuid4()

        stmts = ";".join([
            f"from {module_name} import {cls_name}",
            f"from {func.__module__} import {func.__name__}",
            f"assert {func.__name__}({cls_name}), '{err_id}'",
        ])

        result = subprocess.run([sys.executable, "-c", stmts],
                                capture_output=True)

        if result.returncode != 0:
            err_lines = [line.decode() for line in result.stderr.splitlines()]
            if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
                err_str = "\n".join(err_lines)
                raise RuntimeError(
                    "An unexpected error occurred while importing the model in "
                    f"another process. Error log:\n{err_str}")

        return result.returncode == 0

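    # Concrete expansion of the subprocess check above: for
    # model_arch = "DeepseekV2ForCausalLM" and func = supports_pp, the child
    # process effectively executes (err_id elided):
    #
    #     from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
    #     from vllm.model_executor.models.interfaces import supports_pp
    #     assert supports_pp(DeepseekV2ForCausalLM), '<err_id>'
    #
    # A zero exit code means the check passed, and the parent process never
    # imports the model module itself, so CUDA stays uninitialized.
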
    @staticmethod
    def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return any(arch in _EMBEDDING_MODELS for arch in architectures)

    @staticmethod
    def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        is_mm = partial(ModelRegistry._check_stateless,
                        supports_multimodal,
                        default=False)

        return any(is_mm(arch) for arch in architectures)

    @staticmethod
    def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        is_pp = partial(ModelRegistry._check_stateless,
                        supports_pp,
                        default=False)

        return any(is_pp(arch) for arch in architectures)


__all__ = [

vllm/model_executor/models/arctic.py
@@ -1,12 +1,12 @@
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import (
    DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -32,6 +31,10 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

logger = init_logger(__name__)

@@ -364,6 +367,7 @@ class ArcticModel(nn.Module):
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
@@ -372,15 +376,16 @@ class ArcticModel(nn.Module):
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size)
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(config, int(
                prefix.split(".")[-1]), cache_config, quant_config),
            prefix=f"{prefix}.layers")
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

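    # For orientation, a paraphrased sketch of what the factory above returns
    # (the real helper lives in vllm/model_executor/models/utils.py, which this
    # diff does not show; exact signature assumed):
    #
    #     def make_empty_intermediate_tensors_factory(keys, hidden_size):
    #         def make_empty_intermediate_tensors(batch_size, dtype, device):
    #             # Non-first PP ranks allocate placeholder buffers with the
    #             # model's hidden size; the runner fills them with the
    #             # previous rank's output before forward() runs.
    #             return IntermediateTensors({
    #                 key: torch.zeros((batch_size, hidden_size),
    #                                  dtype=dtype, device=device)
    #                 for key in keys
    #             })
    #         return make_empty_intermediate_tensors
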
    def forward(
        self,
@@ -388,17 +393,25 @@ class ArcticModel(nn.Module):
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


class ArcticForCausalLM(nn.Module, SupportsPP):

    def __init__(self,
                 config: ArcticConfig,
@@ -422,6 +435,8 @@ class ArcticForCausalLM(nn.Module):
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@@ -430,9 +445,9 @@ class ArcticForCausalLM(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
@@ -503,6 +518,8 @@ class ArcticForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -512,6 +529,8 @@ class ArcticForCausalLM(nn.Module):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
@@ -522,6 +541,8 @@ class ArcticForCausalLM(nn.Module):
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param,
@@ -532,6 +553,8 @@ class ArcticForCausalLM(nn.Module):
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]

                        weight_loader = getattr(param, "weight_loader",

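# Orientation sketch for the loader guards above: is_pp_missing_parameter
# (vllm/model_executor/models/utils.py, not shown in this diff) reports
# whether a checkpoint tensor belongs to a layer that this PP rank does not
# own, in which case the loader skips it. Roughly (exact matching logic
# assumed for illustration):
def _is_pp_missing_sketch(name: str, start_layer: int, end_layer: int) -> bool:
    if not name.startswith("model.layers."):
        return False
    layer_idx = int(name.split(".")[2])
    return not (start_layer <= layer_idx < end_layer)
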
vllm/model_executor/models/baichuan.py
@@ -19,7 +19,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -27,7 +27,7 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -35,8 +35,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -45,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -255,7 +256,8 @@ class BaiChuanModel(nn.Module):
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
@@ -265,12 +267,16 @@ class BaiChuanModel(nn.Module):
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
                                                cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

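    # Orientation sketch of make_layers (vllm/model_executor/models/utils.py,
    # not part of this diff): it assigns each PP rank a contiguous slice of
    # layers and keeps placeholders elsewhere so layer indices stay stable.
    # Approximately (helper location assumed; the real code uses a dedicated
    # PPMissingLayer placeholder rather than nn.Identity):
    #
    #     from vllm.distributed.utils import get_pp_indices
    #
    #     def make_layers_sketch(num_layers, layer_fn, pp_rank, pp_size, prefix):
    #         start, end = get_pp_indices(num_layers, pp_rank, pp_size)
    #         layers = nn.ModuleList(
    #             [nn.Identity() for _ in range(start)] +
    #             [layer_fn(prefix=f"{prefix}.{i}") for i in range(start, end)] +
    #             [nn.Identity() for _ in range(end, num_layers)])
    #         return start, end, layers
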
    def forward(
        self,
@@ -278,23 +284,34 @@ class BaiChuanModel(nn.Module):
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
@@ -335,6 +352,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
        self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@@ -343,9 +362,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
@@ -394,6 +413,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -402,6 +423,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
@@ -413,7 +436,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
@@ -431,7 +454,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,

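# Orientation: what a non-final Baichuan PP rank ships to the next rank per
# forward() above. Shapes here are illustrative (4 tokens, hidden size 4096):
from vllm.sequence import IntermediateTensors
import torch

example = IntermediateTensors({
    "hidden_states": torch.zeros(4, 4096),
    "residual": torch.zeros(4, 4096),
})
hidden = example["hidden_states"]  # consumed by the next rank's forward()
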
@ -1,3 +1,4 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -19,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData

from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings)

@ -475,7 +476,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self,
config: Blip2Config,
@ -508,6 +509,16 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler

return Sampler()

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
@ -600,7 +611,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.

One key thing to understand is the `input_ids` already accounts for the
@ -631,26 +642,32 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
See also:
:class:`Blip2ImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)

if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)

input_ids = None
else:
inputs_embeds = None

hidden_states = self.language_model.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)

return hidden_states


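For multimodal models such as BLIP-2, the reordering above is the point: only the first pipeline rank sees real `input_ids`, so image parsing and embedding merging now happen solely when `intermediate_tensors` is `None`; later ranks run their layers directly on the received hidden states. A condensed sketch of that branch structure (the helper names here are placeholders, not vLLM APIs):

    from typing import Any, Optional, Tuple

    def prepare_pp_inputs(input_ids, intermediate_tensors,
                          embed_fn, merge_fn, image_embeds=None
                          ) -> Tuple[Optional[Any], Optional[Any]]:
        # Non-first ranks: hidden states arrive via intermediate_tensors,
        # so both token IDs and input embeddings are unused.
        if intermediate_tensors is not None:
            return None, None
        # First rank with an image: splice vision embeddings into the
        # token embeddings at the image-token positions.
        if image_embeds is not None:
            return None, merge_fn(input_ids, embed_fn(input_ids),
                                  image_embeds)
        # First rank, text only.
        return input_ids, None
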
@ -17,7 +17,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -25,15 +25,14 @@ from transformers import BloomConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
@ -41,6 +40,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@ -222,6 +225,7 @@ class BloomModel(nn.Module):
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.hidden_size
@ -235,13 +239,16 @@ class BloomModel(nn.Module):
self.embed_dim, eps=config.layer_norm_epsilon)

# Transformer blocks
self.h = nn.ModuleList([
BloomBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h")

# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def forward(
self,
@ -249,22 +256,29 @@ class BloomModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states


class BloomForCausalLM(nn.Module):
class BloomForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
@ -284,6 +298,8 @@ class BloomForCausalLM(nn.Module):

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def forward(
self,
@ -292,9 +308,9 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -321,6 +337,8 @@ class BloomForCausalLM(nn.Module):
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]

if "query_key_value" in name:

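BLOOM makes the core mechanics visible: `make_layers` materializes only this rank's contiguous slice of blocks, the loop runs over `[start_layer, end_layer)`, and `kv_caches` is indexed by `i - self.start_layer` because each rank receives caches only for its own layers. A rough sketch of such a contiguous split, assuming an even partition (the real helper also honors user-specified partitions):

    def layer_range(num_layers: int, pp_rank: int, pp_size: int):
        # Even contiguous split across ranks, earlier ranks absorbing the
        # remainder: 30 layers on 4 ranks -> (0, 8), (8, 16), (16, 23), (23, 30).
        per_rank, extra = divmod(num_layers, pp_size)
        start = pp_rank * per_rank + min(pp_rank, extra)
        end = start + per_rank + (1 if pp_rank < extra else 0)
        return start, end
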
@ -1,6 +1,6 @@
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict)
Tuple, TypedDict, Union)

import torch
import torch.nn.functional as F
@ -10,7 +10,7 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@ -33,7 +33,9 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import print_warning_once

from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)

# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
@ -822,6 +824,7 @@ class ChameleonModel(nn.Module):
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -835,14 +838,20 @@ class ChameleonModel(nn.Module):
config.vocabulary_map)
decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
else ChameleonSwinDecoderLayer
self.layers = nn.ModuleList([
decoder_layer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.layers",
)

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.vqmodel = ChameleonVQVAE(config.vq_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -865,22 +874,33 @@ class ChameleonModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

@ -889,7 +909,8 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

def __init__(
self,
@ -914,6 +935,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

@ -956,22 +979,26 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:

image_input = self._parse_and_validate_image_input(**kwargs)
if intermediate_tensors is not None:
input_ids = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(image_input["data"].to(
self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)
if image_input is not None:
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device,
input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)

hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -1039,6 +1066,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -1060,11 +1089,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
continue
else:
name = remapped_kv_scale_name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if use_default_weight_loading and name in params_dict:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@ -2,7 +2,7 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -10,15 +10,14 @@ from torch.nn import LayerNorm

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -28,14 +27,16 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class GLMAttention(nn.Module):

def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
@ -126,7 +127,7 @@ class GLMMLP(nn.Module):

def __init__(
self,
config,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@ -169,7 +170,7 @@ class GLMBlock(nn.Module):

def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
@ -240,9 +241,10 @@ class GLMTransformer(nn.Module):

def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
@ -251,10 +253,11 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers

# Transformer layers.
self.layers = nn.ModuleList([
GLMBlock(config, cache_config, quant_config)
for i in range(self.num_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config),
prefix=f"{prefix}.layers",
)

if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@ -269,16 +272,16 @@ class GLMTransformer(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(self.num_layers):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)
# Final layer norm.
if self.post_layer_norm:
if get_pp_group().is_last_rank and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)

return hidden_states
@ -288,7 +291,7 @@ class ChatGLMModel(nn.Module):

def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
@ -305,6 +308,9 @@ class ChatGLMModel(nn.Module):
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def forward(
self,
@ -312,8 +318,12 @@ class ChatGLMModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.embedding(input_ids)
else:
inputs_embeds = intermediate_tensors["hidden_states"]

# Run encoder.
hidden_states = self.encoder(
@ -322,10 +332,13 @@ class ChatGLMModel(nn.Module):
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states


class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
@ -362,6 +375,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def forward(
self,
@ -370,9 +385,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -402,6 +417,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@ -20,7 +20,7 @@

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
@ -29,14 +29,13 @@ from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -47,7 +46,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


@torch.compile
@ -82,7 +83,7 @@ class CohereMLP(nn.Module):

def __init__(
self,
config,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@ -256,6 +257,7 @@ class CohereModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -265,12 +267,16 @@ class CohereModel(nn.Module):
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
CohereDecoderLayer(config, cache_config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def forward(
self,
@ -278,23 +284,34 @@ class CohereModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class CohereForCausalLM(nn.Module, SupportsLoRA):
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -337,6 +354,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
quant_config,
lora_config=lora_config)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

@torch.no_grad()
def forward(
@ -346,9 +365,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -393,6 +412,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -405,6 +426,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@ -1,20 +1,19 @@
# coding=utf-8
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -24,6 +23,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class DbrxRouter(nn.Module):
"""A Router implementation for DBRX that returns logits for each expert
@ -296,22 +299,27 @@ class DbrxModel(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList([
DbrxBlock(config, cache_config, quant_config)
for _ in range(config.n_layers)
])
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config),
prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model))

def forward(
self,
@ -319,21 +327,28 @@ class DbrxModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states)
return hidden_states


class DbrxForCausalLM(nn.Module):
class DbrxForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
@ -359,6 +374,8 @@ class DbrxForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def forward(
self,
@ -367,9 +384,9 @@ class DbrxForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -401,11 +418,15 @@ class DbrxForCausalLM(nn.Module):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
break
else:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@ -29,11 +29,12 @@ import torch
from transformers import LlamaConfig

from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM

from .utils import is_pp_missing_parameter


class DeciLMForCausalLM(LlamaForCausalLM):
"""
@ -91,6 +92,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -99,6 +102,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -29,7 +29,7 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -50,6 +49,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class DeepseekMLP(nn.Module):

@ -329,6 +332,7 @@ class DeepseekModel(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@ -338,14 +342,17 @@ class DeepseekModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def forward(
self,
@ -353,19 +360,29 @@ class DeepseekModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class DeepseekForCausalLM(nn.Module):
class DeepseekForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
@ -384,6 +401,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -392,9 +411,9 @@ class DeepseekForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -439,6 +458,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -451,6 +472,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -50,7 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class DeepseekV2MLP(nn.Module):
@ -439,6 +440,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def forward(
self,
@ -447,7 +451,7 @@ class DeepseekV2Model(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
@ -472,7 +476,7 @@ class DeepseekV2Model(nn.Module):
return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
@ -492,6 +496,8 @@ class DeepseekV2ForCausalLM(nn.Module):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -500,7 +506,7 @@ class DeepseekV2ForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states

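DeepSeek-V2 (and EXAONE below) put a `PPMissingLayer` where the final norm or LM head would live on ranks that do not own it, so attribute lookups still succeed without allocating weights. A minimal placeholder in that spirit (an assumption about its behavior, not vLLM's exact class):

    import torch.nn as nn

    class PassThroughPlaceholder(nn.Module):
        # Stands in for a module hosted on another pipeline rank:
        # holds no parameters and forwards its first input unchanged.
        def forward(self, *args, **kwargs):
            return args[0] if args else None
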
@ -38,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
@ -53,8 +52,9 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class ExaoneGatedMLP(nn.Module):
@ -354,6 +354,10 @@ class ExaoneModel(nn.Module):
else:
self.ln_f = PPMissingLayer()

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)

@ -397,7 +401,7 @@ class ExaoneModel(nn.Module):
return hidden_states


class ExaoneForCausalLM(nn.Module, SupportsLoRA):
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -477,6 +481,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
else:
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def forward(
self,
input_ids: torch.Tensor,
@ -506,24 +513,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
"residual":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
})

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)

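EXAONE previously hand-rolled `make_empty_intermediate_tensors` (the method deleted above); the shared factory builds the same zero tensors from a list of key names and the hidden size. A sketch matching the shape of the deleted code (illustrative; the real factory returns an `IntermediateTensors` object rather than a plain dict):

    import torch
    from typing import Callable, Dict, List

    def make_empty_intermediate_tensors_factory(
            keys: List[str], hidden_size: int) -> Callable:
        def make_empty(batch_size: int, dtype: torch.dtype,
                       device: torch.device) -> Dict[str, torch.Tensor]:
            # One zeroed placeholder per key, e.g. "hidden_states" and
            # "residual", allocated on the rank that will receive them.
            return {
                key: torch.zeros((batch_size, hidden_size),
                                 dtype=dtype, device=device)
                for key in keys
            }
        return make_empty
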
@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)

FalconConfig = Union[HF_FalconConfig, RWConfig]


@ -333,6 +336,7 @@ class FalconModel(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -347,35 +351,45 @@ class FalconModel(nn.Module):
)

# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.h")

# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def forward(
self,
input_ids: torch.LongTensor,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states


class FalconForCausalLM(nn.Module):
class FalconForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
@ -403,6 +417,8 @@ class FalconForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def forward(
self,
@ -412,12 +428,8 @@ class FalconForCausalLM(nn.Module):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
attn_metadata,
)
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -454,6 +466,8 @@ class FalconForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)

@@ -41,8 +41,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)

-from .interfaces import SupportsMultiModal
-from .utils import flatten_bn, merge_multimodal_embeddings
+from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    merge_multimodal_embeddings)

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
-class FuyuForCausalLM(nn.Module, SupportsMultiModal):
+class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self,
                  config: FuyuConfig,
@@ -242,6 +243,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         self.language_model = PersimmonForCausalLM(config.text_config,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
+
+    @property
+    def sampler(self):
+        return self.language_model.sampler

     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

@@ -297,23 +304,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
     ):
-        image_input = self._parse_and_validate_image_input(**kwargs)
-
-        if image_input is not None:
-            vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, vision_embeddings,
-                self.image_token_id)
-
-        else:
-            inputs_embeds = None
+        if intermediate_tensors is not None:
+            input_ids = None
+            inputs_embeds = None
+        else:
+            image_input = self._parse_and_validate_image_input(**kwargs)
+
+            if image_input is not None:
+                vision_embeddings = self._process_image_input(image_input)
+                inputs_embeds = self.language_model.model.embed_tokens(
+                    input_ids)
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids, inputs_embeds, vision_embeddings,
+                    self.image_token_id)
+
+            else:
+                inputs_embeds = None

         hidden_states = self.language_model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
         return hidden_states
@@ -336,34 +349,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # copy from vllm/model_executor/models/bloom.py
-                # NOTE: Fuyu's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+
+        # load vision embeddings
+        vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
+        for name, loaded_weight in weights_group["vision_embed_tokens"]:
+            param = vision_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        self.language_model.load_weights(weights_group["language_model"])
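The new Fuyu load_weights splits the checkpoint stream by top-level module name and delegates the language-model portion to PersimmonForCausalLM. A minimal sketch of the prefix-grouping idea (the real group_weights_with_prefix helper lives in vllm/model_executor/models/utils.py; this mimic only illustrates the contract assumed here):

    from collections import defaultdict
    from typing import Dict, Iterable, List, Tuple
    import torch

    def group_by_prefix(
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
        groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
        for name, tensor in weights:
            prefix, _, rest = name.partition(".")
            groups[prefix].append((rest, tensor))  # strip the component prefix
        return groups

    weights = [("vision_embed_tokens.weight", torch.zeros(2)),
               ("language_model.lm_head.weight", torch.zeros(2))]
    grouped = group_by_prefix(weights)
    assert sorted(grouped) == ["language_model", "vision_embed_tokens"]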
@@ -15,7 +15,7 @@
 # limitations under the License.
 """Inference-only Gemma model compatible with HuggingFace weights."""
 from functools import lru_cache
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, List, Optional, Set, Tuple, Union

 import torch
 from torch import nn
@@ -23,7 +23,7 @@ from transformers import GemmaConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)

 logger = init_logger(__name__)

@@ -245,6 +246,7 @@ class GemmaModel(nn.Module):
         config: GemmaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -253,10 +255,11 @@ class GemmaModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            GemmaDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
+                                             ),
+            prefix=f"{prefix}.layers")
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         # Normalize the embedding by sqrt(hidden_size)
@@ -265,6 +268,9 @@ class GemmaModel(nn.Module):
         # See https://github.com/huggingface/transformers/pull/29402
         normalizer = self.config.hidden_size**0.5
         self.register_buffer("normalizer", torch.tensor(normalizer))
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -275,29 +281,38 @@ class GemmaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            hidden_states *= self.normalizer
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        hidden_states *= self.normalizer
-        residual = None
-        for i in range(len(self.layers)):
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


-class GemmaForCausalLM(nn.Module, SupportsLoRA):
+class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -339,6 +354,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
         self.model = GemmaModel(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -347,9 +364,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -388,6 +405,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -400,6 +419,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
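The make_layers call above replaces the plain nn.ModuleList so that each pipeline rank only instantiates its own contiguous slice of decoder layers and records its (start_layer, end_layer) window. A sketch of the partitioning arithmetic under an assumed even-split policy (the formula is an illustration, not a copy of vLLM's implementation):

    def layer_range(num_layers: int, rank: int, world_size: int):
        # distribute layers as evenly as possible; earlier ranks absorb the remainder
        per_rank = num_layers // world_size
        extra = num_layers % world_size
        start = rank * per_rank + min(rank, extra)
        end = start + per_rank + (1 if rank < extra else 0)
        return start, end

    # e.g. 28 Gemma layers over 4 pipeline ranks
    assert [layer_range(28, r, 4) for r in range(4)] == \
        [(0, 7), (7, 14), (14, 21), (21, 28)]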
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, List, Optional, Set, Tuple, Union

 import torch
 from torch import nn
@@ -22,7 +22,7 @@ from transformers import Gemma2Config

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
@@ -30,8 +30,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -40,7 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)

 logger = init_logger(__name__)

@@ -244,6 +245,7 @@ class Gemma2Model(nn.Module):
         config: Gemma2Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -252,10 +254,11 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
+                -1]), config, cache_config, quant_config),
+            prefix=f"{prefix}.layers")
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         # Normalize the embedding by sqrt(hidden_size)
@@ -264,6 +267,9 @@ class Gemma2Model(nn.Module):
         # See https://github.com/huggingface/transformers/pull/29402
         normalizer = self.config.hidden_size**0.5
         self.register_buffer("normalizer", torch.tensor(normalizer))
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))

     def forward(
         self,
@@ -271,25 +277,36 @@ class Gemma2Model(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        hidden_states *= self.normalizer
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            hidden_states *= self.normalizer

-        residual = None
-        for i in range(len(self.layers)):
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


-class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
+class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -338,6 +355,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
         self.logits_processor = LogitsProcessor(
             config.vocab_size, soft_cap=config.final_logit_softcapping)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -346,9 +365,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -387,6 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -399,6 +420,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .utils import is_pp_missing_parameter, make_layers
+from .interfaces import SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)


 class GPT2Attention(nn.Module):
@@ -204,6 +205,9 @@ class GPT2Model(nn.Module):
                 config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.h")
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.n_embd))

     def forward(
         self,
@@ -234,7 +238,7 @@ class GPT2Model(nn.Module):
         return hidden_states


-class GPT2LMHeadModel(nn.Module):
+class GPT2LMHeadModel(nn.Module, SupportsPP):

     def __init__(
         self,
@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module):
                                       self.config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata, intermediate_tensors)
         return hidden_states
@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
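GPT-2's hand-written make_empty_intermediate_tensors is deleted above in favor of the shared factory, which builds the zero-filled placeholders that non-first ranks use for profiling. A minimal assumed reimplementation of the factory's behavior (names mirror the vLLM helper, but the body is a sketch):

    from typing import Callable, Dict, List
    import torch

    def make_empty_intermediate_tensors_factory(
            keys: List[str], hidden_size: int) -> Callable:
        def make_empty(batch_size: int, dtype: torch.dtype,
                       device: torch.device) -> Dict[str, torch.Tensor]:
            # one (batch, hidden) zero tensor per requested key
            return {
                key: torch.zeros((batch_size, hidden_size),
                                 dtype=dtype, device=device)
                for key in keys
            }
        return make_empty

    make_empty = make_empty_intermediate_tensors_factory(["hidden_states"], 768)
    tensors = make_empty(2, torch.float16, torch.device("cpu"))
    assert tensors["hidden_states"].shape == (2, 768)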
@@ -18,7 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)


 class GPTBigCodeAttention(nn.Module):
@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module):
                                            self.embed_dim,
                                            org_num_embeddings=config.vocab_size)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-        self.h = nn.ModuleList([
-            GPTBigCodeBlock(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.h",
+        )
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.n_embd))

     def forward(
         self,
@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            inputs_embeds = self.wte(input_ids)
+            position_embeds = self.wpe(position_ids)
+            hidden_states = inputs_embeds + position_embeds
+        else:
+            hidden_states = intermediate_tensors["hidden_states"]

-        for i in range(len(self.h)):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
-            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+            hidden_states = layer(hidden_states,
+                                  kv_caches[i - self.start_layer],
+                                  attn_metadata)

+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
         return hidden_states


-class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
+class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}

     supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
+            if is_pp_missing_parameter(name, self):
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
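Note the recurring kv_caches[i - self.start_layer] rewrite: with pipeline parallelism each rank is handed only its own KV caches, so the global layer index must be shifted into the local cache list. A tiny worked example of that mapping (the slice bounds here are hypothetical):

    start_layer, end_layer = 12, 24          # this rank owns layers 12..23
    kv_caches = [f"cache_{i}" for i in range(end_layer - start_layer)]

    for i in range(start_layer, end_layer):
        local_cache = kv_caches[i - start_layer]   # global -> local index
        assert local_cache == f"cache_{i - start_layer}"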
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-J model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -24,14 +24,13 @@ from transformers import GPTJConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .interfaces import SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+

 class GPTJAttention(nn.Module):

@@ -178,6 +181,7 @@ class GPTJModel(nn.Module):
         config: GPTJConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -186,11 +190,15 @@ class GPTJModel(nn.Module):
             config.vocab_size,
             self.embed_dim,
         )
-        self.h = nn.ModuleList([
-            GPTJBlock(config, cache_config, quant_config)
-            for _ in range(config.n_layer)
-        ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.n_layer,
+            lambda prefix: GPTJBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.h",
+        )
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.n_embd))

     def forward(
         self,
@@ -198,21 +206,27 @@ class GPTJModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.wte(input_ids)
-        for i in range(len(self.h)):
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.wte(input_ids)
+        else:
+            hidden_states = intermediate_tensors["hidden_states"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
             hidden_states = layer(
                 position_ids,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
         return hidden_states


-class GPTJForCausalLM(nn.Module):
+class GPTJForCausalLM(nn.Module, SupportsPP):

     def __init__(
         self,
@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .interfaces import SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+

 class GPTNeoXAttention(nn.Module):

@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module):
         config: GPTNeoXConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            GPTNeoXLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
+            prefix=f"{prefix}.layers",
+        )
         self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                              eps=config.layer_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.hidden_size))

     def forward(
         self,
@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embed_in(input_ids)
-        for i in range(len(self.layers)):
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_in(input_ids)
+        else:
+            hidden_states = intermediate_tensors["hidden_states"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states = layer(
                 position_ids,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.final_layer_norm(hidden_states)
         return hidden_states


-class GPTNeoXForCausalLM(nn.Module):
+class GPTNeoXForCausalLM(nn.Module, SupportsPP):

     def __init__(
         self,
@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module):
         self.embed_out.weight = self.gpt_neox.embed_in.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.gpt_neox.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
-                                      attn_metadata)
+                                      attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module):
                 # Models trained using OpenRLHF may include
                 # these tensors in the checkpoint. Skip them.
                 continue
+            if is_pp_missing_parameter(name, self):
+                continue
            param = params_dict[name]

            if "query_key_value" in name:
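Why do non-last ranks return IntermediateTensors instead of finishing the forward pass? The runner treats such a return as a hand-off to the next stage. A toy end-to-end loop showing the idea (all names here are hypothetical; real vLLM moves the tensors between processes over NCCL rather than through a list):

    import torch

    def run_pipeline(stages, input_ids):
        carry = None
        out = None
        for rank, stage in enumerate(stages):
            out = stage(input_ids if rank == 0 else None, carry)
            # a dict return means "intermediate tensors": pass them onward
            carry = out if isinstance(out, dict) else None
        assert isinstance(out, torch.Tensor), "last rank must return hidden states"
        return out

    stages = [
        lambda ids, _: {"hidden_states": torch.ones(len(ids), 4)},  # rank 0
        lambda _, carry: carry["hidden_states"] * 2,                # last rank
    ]
    print(run_pipeline(stages, [101, 102]))  # shape (2, 4), all 2.0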
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_hip

-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


@@ -311,13 +311,13 @@ class GraniteModel(nn.Module):
             else:
                 hidden_states = self.get_input_embeddings(input_ids)
             residual = None
+
+            hidden_states *= self.config.embedding_multiplier
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-
-        hidden_states *= self.config.embedding_multiplier

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states = layer(
@@ -337,7 +337,7 @@ class GraniteModel(nn.Module):
         return hidden_states


-class GraniteForCausalLM(nn.Module, SupportsLoRA):
+class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from . import mixtral
-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import make_layers


@@ -307,7 +307,7 @@ class GraniteMoeModel(nn.Module):
         return hidden_states


-class GraniteMoeForCausalLM(nn.Module, SupportsLoRA):
+class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
@@ -1,11 +1,17 @@
-from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
-                    Union, overload, runtime_checkable)
+import inspect
+from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
+                    Protocol, Type, Union, overload, runtime_checkable)

 import torch
 from typing_extensions import TypeIs

-from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
 from vllm.logger import init_logger

+if TYPE_CHECKING:
+    from vllm.attention import AttentionMetadata
+    from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
+    from vllm.sequence import IntermediateTensors
+
 logger = init_logger(__name__)

@@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol):
         MRO of your model class.
     """

-    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
+    def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
         ...


@@ -32,7 +38,7 @@ class SupportsMultiModal(Protocol):
 class _SupportsMultiModalType(Protocol):
     supports_multimodal: Literal[True]

-    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
+    def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
         ...


@@ -75,7 +81,7 @@ class SupportsLoRA(Protocol):
     embedding_padding_modules: ClassVar[List[str]]

     # lora_config is None when LoRA is not enabled
-    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+    def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
         ...


@@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol):
     embedding_modules: Dict[str, str]
     embedding_padding_modules: List[str]

-    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+    def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
         ...


@@ -145,6 +151,132 @@ def _supports_lora(
     return isinstance(model, SupportsLoRA)


+@runtime_checkable
+class SupportsPP(Protocol):
+    """The interface required for all models that support pipeline parallel."""
+
+    supports_pp: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports pipeline parallel.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    def make_empty_intermediate_tensors(
+        self,
+        batch_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "IntermediateTensors":
+        """Called when PP rank > 0 for profiling purposes."""
+        ...
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: "AttentionMetadata",
+        intermediate_tensors: Optional["IntermediateTensors"],
+    ) -> Union[torch.Tensor, "IntermediateTensors"]:
+        """
+        Accept :class:`IntermediateTensors` when PP rank > 0.
+
+        Return :class:`IntermediateTensors` only for the last PP rank.
+        """
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsPPType(Protocol):
+    supports_pp: Literal[True]
+
+    def make_empty_intermediate_tensors(
+        self,
+        batch_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "IntermediateTensors":
+        ...
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: "AttentionMetadata",
+        intermediate_tensors: Optional["IntermediateTensors"],
+    ) -> Union[torch.Tensor, "IntermediateTensors"]:
+        ...
+
+
+@overload
+def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
+    ...
+
+
+@overload
+def supports_pp(model: object) -> TypeIs[SupportsPP]:
+    ...
+
+
+def supports_pp(
+    model: Union[Type[object], object],
+) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+    supports_attributes = _supports_pp_attributes(model)
+    supports_inspect = _supports_pp_inspect(model)
+
+    if supports_attributes and not supports_inspect:
+        logger.warning(
+            "The model (%s) sets `supports_pp=True`, but does not accept "
+            "`intermediate_tensors` in its `forward` method", model)
+
+    if not supports_attributes:
+        pp_attrs = ("make_empty_intermediate_tensors", )
+        missing_attrs = tuple(attr for attr in pp_attrs
+                              if not hasattr(model, attr))
+
+        if getattr(model, "supports_pp", False):
+            if missing_attrs:
+                logger.warning(
+                    "The model (%s) sets `supports_pp=True`, "
+                    "but is missing PP-specific attributes: %s",
+                    model,
+                    missing_attrs,
+                )
+        else:
+            if not missing_attrs:
+                logger.warning(
+                    "The model (%s) contains all PP-specific attributes, "
+                    "but does not set `supports_pp=True`.", model)
+
+    return supports_attributes and supports_inspect
+
+
+def _supports_pp_attributes(
+    model: Union[Type[object], object],
+) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsPPType)
+
+    return isinstance(model, SupportsPP)
+
+
+def _supports_pp_inspect(
+    model: Union[Type[object], object],
+) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+    model_forward = getattr(model, "forward", None)
+    if not callable(model_forward):
+        return False
+
+    forward_params = inspect.signature(model_forward).parameters
+    return "intermediate_tensors" in forward_params
+
+
 @runtime_checkable
 class HasInnerState(Protocol):
     """The interface required for all models that has inner state."""
@@ -158,7 +290,7 @@ class HasInnerState(Protocol):

     def __init__(self,
                  *,
-                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+                 scheduler_config: Optional["SchedulerConfig"] = None) -> None:
         ...


@@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol):

     def __init__(self,
                  *,
-                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+                 scheduler_config: Optional["SchedulerConfig"] = None) -> None:
         ...

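The supports_pp helper above combines an attribute check with a signature inspection, so callers can gate pipeline parallelism at load time. A minimal sketch of consuming the new interface, assuming vLLM is importable (the engine-side check and MyModel are illustrative, not code from this commit):

    from vllm.model_executor.models.interfaces import supports_pp

    class MyModel:  # deliberately lacks the SupportsPP attributes
        pass

    if not supports_pp(MyModel):
        raise NotImplementedError(
            "Pipeline parallelism requires the model to implement SupportsPP")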
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        intermediate_tensors: IntermediateTensors = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module):
         return hidden_states


-class InternLM2ForCausalLM(nn.Module):
+class InternLM2ForCausalLM(nn.Module, SupportsPP):

     def __init__(
         self,
@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        intermediate_tensors: IntermediateTensors,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, intermediate_tensors)
@@ -5,9 +5,9 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 import re
-from functools import partial
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
-                    Tuple, TypedDict, Union)
+from functools import cached_property, partial
+from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+                    TypedDict, Union)

 import torch
 import torch.nn as nn
@@ -17,7 +17,6 @@ from transformers import PretrainedConfig

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -32,7 +31,7 @@ from vllm.utils import is_list_of

 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, group_weights_with_prefix,
                     init_vllm_registered_model, merge_multimodal_embeddings)

@@ -123,7 +122,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     return blocks, target_width, target_height


-def calculate_num_blocks_wrapper(hf_config: Dict[str, Any],
+def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
     if max_dynamic_patch is None:
         max_dynamic_patch = hf_config.max_dynamic_patch
@@ -183,7 +182,7 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
     return pixel_values


-def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
+def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                   max_dynamic_patch: Optional[int] = None):
     image_size = hf_config.vision_config.image_size
     min_num = hf_config.min_dynamic_patch
@@ -197,7 +196,7 @@ def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
                                   use_thumbnail=use_thumbnail)


-def get_internvl_num_patches(hf_config: Dict[str, Any]):
+def get_internvl_num_patches(hf_config: PretrainedConfig):
     vision_config = hf_config.vision_config
     downsample_ratio = hf_config.downsample_ratio
     image_size = vision_config.image_size
@@ -362,7 +361,7 @@ def dummy_data_for_internvl(ctx: InputContext,
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
-class InternVLChatModel(nn.Module, SupportsMultiModal):
+class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self,
                  config: PretrainedConfig,
@@ -408,10 +407,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

-        if hasattr(self.language_model, "sampler"):
-            self.sampler = self.language_model.sampler
-        else:
-            self.sampler = Sampler()
+    @cached_property
+    def sampler(self):
+        if hasattr(self.language_model, "sampler"):
+            return self.language_model.sampler
+
+        return Sampler()

     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -515,18 +516,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
-    ) -> SamplerOutput:
-        image_input = self._parse_and_validate_image_input(**kwargs)
-        if image_input is not None and get_pp_group().is_first_rank:
-            inputs_embeds = self.language_model.model.get_input_embeddings(
-                input_ids)
-            vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, vision_embeddings,
-                self.img_context_token_id)
-            input_ids = None
-        else:
-            inputs_embeds = None
+    ) -> Union[SamplerOutput, IntermediateTensors]:
+        if intermediate_tensors is not None:
+            input_ids = None
+            inputs_embeds = None
+        else:
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            if image_input is not None:
+                inputs_embeds = self.language_model.model.get_input_embeddings(
+                    input_ids)
+                vision_embeddings = self._process_image_input(image_input)
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids, inputs_embeds, vision_embeddings,
+                    self.img_context_token_id)
+                input_ids = None
+            else:
+                inputs_embeds = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import JAISConfig
|
||||
|
||||
from .utils import is_pp_missing_parameter, make_layers
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
class SwiGLUActivation(nn.Module):
|
||||
@ -244,6 +245,9 @@ class JAISModel(nn.Module):
|
||||
)
|
||||
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.n_embd))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -279,7 +283,7 @@ class JAISModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JAISLMHeadModel(nn.Module):
|
||||
class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
|
||||
scale=self.output_logits_scale)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
|
||||
@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
@ -51,8 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class LlamaMLP(nn.Module):

@ -72,12 +72,15 @@ class LlamaMLP(nn.Module):
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
@ -161,12 +164,14 @@ class LlamaAttention(nn.Module):
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
@ -248,12 +253,10 @@ class LlamaDecoderLayer(nn.Module):
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
@ -295,12 +298,17 @@ class LlamaModel(nn.Module):
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers")
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

@ -326,13 +334,9 @@ class LlamaModel(nn.Module):

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
@ -344,17 +348,10 @@ class LlamaModel(nn.Module):
        return hidden_states


class LlamaForCausalLM(nn.Module, SupportsLoRA):
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
@ -364,7 +361,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
@ -420,10 +417,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else
                lora_config.lora_vocab_padding_size),
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
@ -436,6 +435,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -458,28 +459,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
@ -513,7 +497,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
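The start_layer/end_layer bookkeeping above comes from make_layers, which is used but not shown in this diff. A rough sketch of the contract these call sites rely on, assuming each rank instantiates only its contiguous slice of layers and stubs out the rest; the actual helper in .utils may split layers differently (e.g. honoring a partition override):

from typing import Callable, List, Tuple

import torch.nn as nn

from vllm.distributed import get_pp_group
from vllm.model_executor.models.utils import PPMissingLayer


def make_layers_sketch(
        num_hidden_layers: int,
        layer_fn: Callable[[str], nn.Module],
        prefix: str) -> Tuple[int, int, nn.ModuleList]:
    """Partition decoder layers contiguously across pipeline ranks."""
    pp_rank = get_pp_group().rank_in_group
    pp_size = get_pp_group().world_size
    per_rank = num_hidden_layers // pp_size
    start_layer = pp_rank * per_rank
    end_layer = (start_layer + per_rank
                 if pp_rank < pp_size - 1 else num_hidden_layers)
    modules: List[nn.Module] = []
    for idx in range(num_hidden_layers):
        if start_layer <= idx < end_layer:
            modules.append(layer_fn(f"{prefix}.{idx}"))
        else:
            # Keeps global layer indices stable without allocating weights.
            modules.append(PPMissingLayer())
    return start_layer, end_layer, nn.ModuleList(modules)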
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -8,10 +8,13 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import is_pp_missing_parameter
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module):
|
||||
class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
"""A model that uses Llama with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the LlamaModel and provides an interface for
|
||||
@ -29,6 +32,8 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
super().__init__()
|
||||
self.model = LlamaModel(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -36,10 +41,12 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model.forward(input_ids, positions, kv_caches,
|
||||
attn_metadata, inputs_embeds)
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
@ -73,6 +80,8 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@ -81,6 +90,8 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
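The is_pp_missing_parameter guards added to every load_weights in this commit all handle the same situation: on a given PP rank, params_dict only contains the layers that rank owns, so checkpoint entries for other ranks' layers must be skipped rather than raise KeyError. A hedged sketch of the check, inferred from these call sites rather than copied from .utils:

import torch.nn as nn

from vllm.model_executor.models.utils import PPMissingLayer


def is_pp_missing_parameter_sketch(name: str, model: nn.Module) -> bool:
    """True if `name` belongs to a layer this PP rank replaced with a stub."""
    for module_name, module in model.named_modules():
        if (module_name and isinstance(module, PPMissingLayer)
                and name.startswith(module_name)):
            return True
    return False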
@ -1,3 +1,4 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -21,7 +22,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
@ -198,7 +199,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(self,
                 config: LlavaConfig,
@ -220,6 +221,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
@ -315,7 +326,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
@ -351,26 +362,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states
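The restructured forward above encodes one invariant: only the first PP rank sees real token IDs, so the vision tower and the embedding merge run only when intermediate_tensors is None. merge_multimodal_embeddings itself is used but not defined in this diff; a plausible sketch of the scatter it performs, assuming the common convention that image positions are marked by a placeholder token id (this is a sketch, not vLLM's exact code, and real vision embeddings may arrive nested per image):

import torch


def merge_multimodal_embeddings_sketch(
        input_ids: torch.Tensor,          # [num_tokens]
        inputs_embeds: torch.Tensor,      # [num_tokens, hidden_size]
        vision_embeddings: torch.Tensor,  # [num_image_tokens, hidden_size]
        image_token_id: int) -> torch.Tensor:
    """Scatter vision embeddings into the slots marked by image_token_id."""
    mask = input_ids == image_token_id
    assert mask.sum() == vision_embeddings.shape[0], (
        "prompt must reserve exactly one placeholder per image embedding")
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[mask] = vision_embeddings.to(inputs_embeds.dtype)
    return inputs_embeds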
@ -1,3 +1,4 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

@ -13,7 +14,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -23,7 +24,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):

    def __init__(self,
                 config: LlavaNextConfig,
@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )
@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states
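SupportsPP itself is never shown in this diff. Judging from how it is used here (a mixin added next to SupportsMultiModal, presumably checked by the model registry), a hedged sketch of such a marker interface could look like the following; the real definition in .interfaces may differ:

from typing import Protocol, runtime_checkable

import torch

from vllm.sequence import IntermediateTensors


@runtime_checkable
class SupportsPPSketch(Protocol):
    """Marker for models whose forward() can resume from IntermediateTensors."""

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        ...


# Usage sketch: an engine can gate pipeline parallelism on the marker.
def assert_pp_capable(model: object) -> None:
    if not isinstance(model, SupportsPPSketch):
        raise ValueError(f"{type(model).__name__} does not support PP")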
@ -1,4 +1,5 @@
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

@ -12,9 +13,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module):
    "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):

    def __init__(self,
                 config: LlavaNextVideoConfig,
@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = _init_vision_tower(config)
        self.vision_resampler = LlavaNextVideoPooler(config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)
        self.vision_resampler = LlavaNextVideoPooler(config)

        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT-Video.
        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values_videos: Pixels in each frames for each input videos.
        """
        video_input = self._parse_and_validate_video_input(**kwargs)

        # merge video embeddings into input embeddings
        if video_input is not None:
            video_embeddings = self._process_video_pixels(video_input)
            inputs_embeds = self.language_model \
                .model.get_input_embeddings(input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, video_embeddings,
                self.config.video_token_index)

            input_ids = None
        else:
            inputs_embeds = None
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            video_input = self._parse_and_validate_video_input(**kwargs)
            if video_input is not None:
                video_embeddings = self._process_video_pixels(video_input)
                inputs_embeds = self.language_model \
                    .model.get_input_embeddings(input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, video_embeddings,
                    self.config.video_token_index)

                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states
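The @cached_property sampler added to these multimodal wrappers is a small delegation trick: reuse the wrapped language model's sampler when this PP rank actually has one, otherwise fall back to a fresh instance, and memoize the choice. A standalone illustration of the pattern (the stand-in class below is hypothetical):

from functools import cached_property


class _FallbackSampler:
    """Stand-in for vllm.model_executor.layers.sampler.Sampler."""


class WrapperSketch:

    def __init__(self, language_model):
        self.language_model = language_model

    @cached_property
    def sampler(self):
        # Reuse the wrapped LM's sampler when it exposes one (last PP rank),
        # so wrapper and LM sample through the same object; otherwise build
        # our own. cached_property evaluates this once per instance.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return _FallbackSampler()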
@ -1,4 +1,5 @@
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

@ -17,9 +18,8 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -31,7 +31,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
                   dummy_video_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                     dummy_video_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
@ -414,7 +414,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
    "video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):

    def __init__(self,
                 config: LlavaOnevisionConfig,
@ -434,6 +435,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

@ -805,39 +816,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-Onevision.
        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values_videos: Pixels in each frames for each input videos.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        # merge video embeddings into input embeddings
        if modalities:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            if "images" in modalities:
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)
            if "videos" in modalities:
                video_input = modalities["videos"]
                video_embeddings = self._process_video_pixels(video_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, video_embeddings,
                    self.config.video_token_index)
            input_ids = None
        else:
            inputs_embeds = None
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
            if modalities:
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                if "images" in modalities:
                    image_input = modalities["images"]
                    vision_embeddings = self._process_image_input(image_input)
                    inputs_embeds = merge_multimodal_embeddings(
                        input_ids, inputs_embeds, vision_embeddings,
                        self.config.image_token_index)
                if "videos" in modalities:
                    video_input = modalities["videos"]
                    video_embeddings = self._process_video_pixels(video_input)
                    inputs_embeds = merge_multimodal_embeddings(
                        input_ids, inputs_embeds, video_embeddings,
                        self.config.video_token_index)
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states
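Unlike the single-modality models above, LLaVA-Onevision may carry images and videos in one batch, so the merge is applied twice, once per placeholder token index. A runnable toy (all ids and tensors hypothetical) showing why successive merges compose safely: each scatter only writes rows whose placeholder id matches, and the two placeholder ids select disjoint rows.

import torch

IMAGE_ID, VIDEO_ID = 32000, 32001  # hypothetical placeholder token ids
input_ids = torch.tensor([1, IMAGE_ID, 5, VIDEO_ID, 2])
embeds = torch.zeros(5, 4)
image_emb = torch.ones(1, 4)
video_emb = torch.full((1, 4), 2.0)

embeds[input_ids == IMAGE_ID] = image_emb  # touches row 1 only
embeds[input_ids == VIDEO_ID] = video_emb  # touches row 3 only
assert embeds[1, 0] == 1.0 and embeds[3, 0] == 2.0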
@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -30,7 +30,7 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class MiniCPMMoE(nn.Module):

@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module):
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self._init_layers()
        self._init_layers(prefix, config, cache_config, quant_config)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))

    def _init_layers(self):
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(self.config, self.cache_config,
                                self.quant_config)
            for _ in range(self.config.num_hidden_layers)
        ])
    def _init_layers(
        self,
        prefix: str,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig],
        quant_config: Optional[QuantizationConfig],
    ):
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiniCPMDecoderLayer(config, cache_config,
                                               quant_config),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedding = self.embed_tokens(input_ids)
@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(len(self.layers)):
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _init_model(self):
        self.model = MiniCPMModel(config=self.config,
@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states
@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
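The MiniCPM forward rewrite above is the same rank-gating skeleton this PR applies everywhere: only the first rank embeds tokens, intermediate ranks hand an IntermediateTensors payload to the next stage, and only the last rank applies the final norm. Distilled into one function, using only the APIs visible in this diff (the tuple-returning norm follows the Llama/Mixtral variant; MiniCPM's norm returns a plain tensor):

def pp_forward_skeleton(self, input_ids, positions, kv_caches,
                        attn_metadata, intermediate_tensors):
    from vllm.distributed import get_pp_group
    from vllm.sequence import IntermediateTensors

    if get_pp_group().is_first_rank:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
    else:
        # Later stages resume from the payload produced upstream.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        hidden_states, residual = layer(positions, hidden_states,
                                        kv_caches[i - self.start_layer],
                                        attn_metadata, residual)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual,
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states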
@ -26,6 +26,7 @@ from typing import Any, Dict, Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
                                                MiniCPMForCausalLM,
                                                MiniCPMModel)

from .utils import make_layers


class MiniCPM3Attention(nn.Module):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
@ -199,12 +201,18 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):

class MiniCPM3Model(MiniCPMModel):

    def _init_layers(self):
        self.layers = nn.ModuleList([
            MiniCPM3DecoderLayer(self.config, self.cache_config,
                                 self.quant_config)
            for _ in range(self.config.num_hidden_layers)
        ])
    def _init_layers(
        self,
        prefix: str,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig],
        quant_config: Optional[QuantizationConfig],
    ):
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
                                                quant_config),
            prefix=f"{prefix}.layers")


class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
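Note how the _init_layers signature change from the previous file pays off here: MiniCPM3Model swaps in a different decoder-layer class by overriding one hook, and the PP slicing from make_layers comes along for free. Condensed (class and import paths as used in the diff; treat this as a schematic, not a copy of the modules):

from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
                                                MiniCPMModel)
from vllm.model_executor.models.utils import make_layers


class MiniCPM3ModelSchematic(MiniCPMModel):

    def _init_layers(self, prefix, config, cache_config, quant_config):
        # Same partitioning as the base class, different layer type.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda p: MiniCPMDecoderLayer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers")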
@ -45,7 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -59,7 +58,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData

from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter

_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
    return MultiModalInputs(batch_data)


class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
    """
    The abstract class of MiniCPMV can only be inherited, but cannot be
    instantiated.
@ -374,6 +374,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        self.make_empty_intermediate_tensors = (
            self.llm.make_empty_intermediate_tensors)

    def get_embedding(
        self,
        input_ids: torch.Tensor,
@ -498,9 +501,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)

        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
        if intermediate_tensors is not None:
            vlm_embeddings = None
        else:
            image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)

            vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

        output = self.llm(
            input_ids=None,
@ -557,6 +563,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    if is_pp_missing_parameter(
                            name.replace(weight_name, param_name), self):
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
@ -564,6 +573,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
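Throughout this commit, IntermediateTensors behaves like a small tensor dict: it is constructed from a dict of named tensors on the sending rank and read back by key on the receiving one. For intuition, a minimal stand-in exposing just the two operations these models rely on (the real class in vllm.sequence carries more machinery):

import torch


class IntermediateTensorsSketch:
    """Minimal stand-in: construct from a dict, read back by key."""

    def __init__(self, tensors):
        self.tensors = dict(tensors)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


pp_payload = IntermediateTensorsSketch({
    "hidden_states": torch.zeros(8, 4096),
    "residual": torch.zeros(8, 4096),
})
assert pp_payload["residual"].shape == (8, 4096)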
@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class MixtralMoE(nn.Module):

@ -276,6 +276,9 @@ class MixtralModel(nn.Module):
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
@ -284,7 +287,7 @@ class MixtralModel(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
@ -306,7 +309,7 @@ class MixtralModel(nn.Module):
        return hidden_states


class MixtralForCausalLM(nn.Module, SupportsLoRA):
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
@ -365,6 +368,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -373,7 +378,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states
@ -387,20 +392,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
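All of the *ForCausalLM classes gaining SupportsPP in this commit become candidates for pipeline-parallel serving. Assuming the standard vLLM entry points (not part of this diff; argument names may shift between versions), enabling PP is a constructor flag:

from vllm import LLM

# Hypothetical launch: split a Mixtral checkpoint across 2 pipeline stages,
# with tensor parallelism of 2 inside each stage (4 GPUs total).
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    pipeline_parallel_size=2,
    tensor_parallel_size=2,
)
outputs = llm.generate("The capital of France is")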
@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
@ -31,7 +31,7 @@ from transformers import MixtralConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class MixtralMLP(nn.Module):

@ -296,6 +299,7 @@ class MixtralModel(nn.Module):
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
@ -305,13 +309,15 @@ class MixtralModel(nn.Module):
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config,
                                cache_config,
                                quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
@ -319,19 +325,30 @@ class MixtralModel(nn.Module):
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
class MixtralForCausalLM(nn.Module, SupportsPP):
    fall_back_to_pt_during_load = False

    def __init__(
@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module):
        self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module):
                    if ("block_sparse_moe.experts." in name
                            and name not in params_dict):
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
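The kv_caches[i - self.start_layer] change above is easy to misread: the layer index i stays global (it runs over self.start_layer..self.end_layer), while the kv_caches list a rank receives is local, holding one entry per layer hosted on that rank. Hence the offset. A toy check of the arithmetic:

# Suppose 8 layers split over 2 PP ranks; rank 1 hosts global layers 4..7.
start_layer, end_layer = 4, 8
kv_caches = ["kv4", "kv5", "kv6", "kv7"]  # local list, one per hosted layer

for i in range(start_layer, end_layer):
    assert kv_caches[i - start_layer] == f"kv{i}"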
@ -1,22 +1,21 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
@ -25,6 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.mpt import MPTConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


def _get_alibi_slopes(
    total_num_heads: int,
@ -208,6 +211,7 @@ class MPTModel(nn.Module):
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        assert config.embedding_fraction == 1.0
@ -217,10 +221,10 @@ class MPTModel(nn.Module):
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            MPTBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: MPTBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.blocks")
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
@ -228,6 +232,9 @@ class MPTModel(nn.Module):
                        module.bias, nn.Parameter):
                    # Remove the bias term in Linear and LayerNorm.
                    module.register_parameter("bias", None)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def forward(
        self,
@ -235,21 +242,29 @@ class MPTModel(nn.Module):
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):
class MPTForCausalLM(nn.Module, SupportsPP):

    def __init__(
        self,
@ -266,6 +281,8 @@ class MPTForCausalLM(nn.Module):
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
@ -274,9 +291,9 @@ class MPTForCausalLM(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
@ -302,6 +319,8 @@ class MPTForCausalLM(nn.Module):
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
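One detail worth flagging in the MPT diff: because an MPT block returns a single tensor rather than a (hidden_states, residual) pair, its factory registers only ["hidden_states"], and the early-exit returns a one-key IntermediateTensors. The models with a fused residual path (Llama, Mixtral, and Nemotron below) register ["hidden_states", "residual"] instead. Both sides of a pipeline boundary must agree on this key set, since the receiving rank indexes the payload by exactly these names.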
@ -34,8 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -46,8 +45,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronConfig

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)

# The architecture is pretty similar to Llama, with these changes:
# - There is no gate_proj, just up_proj
@ -328,6 +328,9 @@ class NemotronModel(nn.Module):
eps=config.norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -372,7 +375,7 @@ class NemotronModel(nn.Module):
return hidden_states


class NemotronForCausalLM(nn.Module, SupportsLoRA):
class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -440,6 +443,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -470,20 +475,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
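
The Nemotron hunk above deletes a hand-written make_empty_intermediate_tensors method in favor of the shared make_empty_intermediate_tensors_factory helper. A plausible reading of what that factory returns, reconstructed from the method it replaces; a plain dict stands in for vLLM's IntermediateTensors container, and the real helper in vllm/model_executor/models/utils.py may differ in detail.

# Sketch of make_empty_intermediate_tensors_factory, reconstructed from the
# hand-written Nemotron method deleted above.
from typing import Callable, Dict, List

import torch

def make_empty_intermediate_tensors_factory(
    keys: List[str], hidden_size: int
) -> Callable[[int, torch.dtype, torch.device], Dict[str, torch.Tensor]]:

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> Dict[str, torch.Tensor]:
        # One zero buffer per key ("hidden_states" and, for models with a
        # separate residual stream, "residual"), pre-allocating what a rank
        # expects to receive from its predecessor.
        return {
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        }

    return make_empty_intermediate_tensors
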
@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -29,14 +29,13 @@ from transformers import OlmoConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -45,6 +44,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class OlmoAttention(nn.Module):
"""
@ -223,19 +226,24 @@ class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config

self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
OlmoDecoderLayer(config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def forward(
self,
@ -243,34 +251,41 @@ class OlmoModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids)
if get_pp_group().is_first_rank:
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids)

# embed positions
hidden_states = inputs_embeds
# embed positions
hidden_states = inputs_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]

# Apply blocks one-by-one.
for layer_idx, decoder_layer in enumerate(self.layers):
for i in range(self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model)
hidden_states = decoder_layer(
hidden_states = self.layers[i](
positions,
hidden_states,
kv_caches[layer_idx],
kv_caches[i - self.start_layer],
attn_metadata,
)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
hidden_states = self.norm(hidden_states)
return hidden_states


class OlmoForCausalLM(nn.Module):
class OlmoForCausalLM(nn.Module, SupportsPP):
"""
Extremely barebones HF model wrapper.
"""
@ -294,6 +309,8 @@ class OlmoForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -302,12 +319,13 @@ class OlmoForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
return hidden_states

@ -358,6 +376,8 @@ class OlmoForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -366,6 +386,8 @@ class OlmoForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
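
OlmoModel above swaps its nn.ModuleList for make_layers, which hands back (start_layer, end_layer, layers) so that indexing with the global layer id still works on every rank. A simplified sketch of that helper, assuming an even split across ranks; the split policy and names here are assumptions rather than vLLM's exact implementation, but the placeholder trick matches how `self.layers[i]` stays valid with global indices.

# Simplified sketch of make_layers: only this rank's slice is materialized,
# the rest become placeholders so `self.layers[i]` works with global indices.
from typing import Callable, Tuple

import torch.nn as nn

class PPMissingLayer(nn.Module):
    """Placeholder for a layer that lives on another pipeline rank."""

def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    rank: int,
    world_size: int,
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    per_rank = num_layers // world_size      # assumes an even split
    start, end = rank * per_rank, (rank + 1) * per_rank
    layers = nn.ModuleList([
        layer_fn(f"{prefix}.{i}") if start <= i < end else PPMissingLayer()
        for i in range(num_layers)
    ])
    return start, end, layers
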
@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -18,15 +18,14 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -36,6 +35,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class OlmoeMoE(nn.Module):
"""A tensor-parallel MoE implementation for Olmoe that shards each expert
@ -243,6 +246,7 @@ class OlmoeModel(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@ -252,34 +256,54 @@ class OlmoeModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
OlmoeDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class OlmoeForCausalLM(nn.Module):
class OlmoeForCausalLM(nn.Module, SupportsPP):

fall_back_to_pt_during_load = False

@ -299,6 +323,9 @@ class OlmoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
input_ids: torch.Tensor,
@ -306,9 +333,9 @@ class OlmoeForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
@ -363,6 +390,9 @@ class OlmoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue

@ -376,6 +406,9 @@ class OlmoeForCausalLM(nn.Module):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
@ -388,6 +421,9 @@ class OlmoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
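
OlmoeModel is one of the models whose IntermediateTensors carry a "residual" entry alongside "hidden_states": its decoder layers use fused add-and-RMSNorm, so the normalized activations and the running residual must cross the rank boundary as a pair. A toy illustration of why both tensors are needed; the layer internals below are stand-ins, not vLLM's fused kernels.

# Toy fused add + RMSNorm step: each layer consumes and produces the
# (hidden, residual) pair, so a PP boundary must forward both tensors.
from typing import Dict, Optional, Tuple

import torch

def rmsnorm_step(hidden: torch.Tensor,
                 residual: Optional[torch.Tensor],
                 weight: torch.Tensor,
                 eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
    residual = hidden if residual is None else hidden + residual
    variance = residual.pow(2).mean(-1, keepdim=True)
    return residual * torch.rsqrt(variance + eps) * weight, residual

def pack_for_next_rank(hidden: torch.Tensor,
                       residual: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Mirrors IntermediateTensors({"hidden_states": ..., "residual": ...}).
    return {"hidden_states": hidden, "residual": residual}
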
@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -25,15 +25,14 @@ from transformers import OPTConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
@ -41,6 +40,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class OPTLearnedPositionalEmbedding(nn.Embedding):

@ -189,6 +192,7 @@ class OPTDecoder(nn.Module):
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -232,10 +236,10 @@ class OPTDecoder(nn.Module):
else:
self.final_layer_norm = None

self.layers = nn.ModuleList([
OPTDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OPTDecoderLayer(config, cache_config, quant_config),
prefix=f"{prefix}.layers")

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -246,19 +250,28 @@ class OPTDecoder(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]

for i in range(len(self.layers)):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
@ -276,6 +289,9 @@ class OPTModel(nn.Module):
):
super().__init__()
self.decoder = OPTDecoder(config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids)
@ -286,20 +302,22 @@ class OPTModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
return self.decoder(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)


class OPTForCausalLM(nn.Module):
class OPTForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
config,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
@ -314,6 +332,8 @@ class OPTForCausalLM(nn.Module):
config.word_embed_proj_dim)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -322,9 +342,9 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -365,6 +385,8 @@ class OPTForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -373,6 +395,8 @@ class OPTForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
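
A detail repeated in every converted forward above: the KV-cache list a rank receives contains entries only for its own layers, so the global layer index must be shifted into local coordinates with kv_caches[i - self.start_layer]. A quick, runnable check of that arithmetic, assuming an even layer split:

# Each rank's kv_caches list is local, so global layer i maps to i - start.
num_layers, world_size = 24, 4
per_rank = num_layers // world_size
for rank in range(world_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    local = [i - start for i in range(start, end)]
    assert local == list(range(per_rank))  # always 0..per_rank-1 on every rank
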
@ -4,7 +4,7 @@
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -12,14 +12,13 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -28,6 +27,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class OrionMLP(nn.Module):

@ -210,6 +213,7 @@ class OrionModel(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -219,11 +223,18 @@ class OrionModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
OrionDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OrionDecoderLayer(
config,
cache_config,
quant_config,
),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def forward(
self,
@ -231,23 +242,34 @@ class OrionModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states


class OrionForCausalLM(nn.Module):
class OrionForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
@ -266,6 +288,8 @@ class OrionForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -274,9 +298,9 @@ class OrionForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@ -321,6 +345,8 @@ class OrionForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -329,6 +355,8 @@ class OrionForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
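
The other change every load_weights() gains above is the is_pp_missing_parameter guard, which skips checkpoint tensors belonging to layers hosted on other ranks. A hypothetical re-creation of that check for illustration; the name parsing is an assumption, since the real helper in .utils consults the model's registered PP-missing modules rather than the parameter name alone.

# Hypothetical sketch of the is_pp_missing_parameter() guard.
import re

def is_pp_missing_parameter_sketch(name: str, start_layer: int,
                                   end_layer: int) -> bool:
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return False                 # embeddings/norms are handled separately
    layer_idx = int(match.group(1))
    return not start_layer <= layer_idx < end_layer

# A rank owning layers [8, 16) skips weights from layer 3:
assert is_pp_missing_parameter_sketch("model.layers.3.mlp.up_proj.weight",
                                      8, 16)
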
@ -9,9 +9,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -19,7 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import group_weights_with_prefix, merge_multimodal_embeddings
@ -129,7 +128,8 @@ class PaliGemmaMultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

def __init__(self,
config: PaliGemmaConfig,
@ -149,12 +149,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.quant_config = quant_config
self.language_model = GemmaForCausalLM(config.text_config,
cache_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = Sampler()
self.language_model.logits_processor.scale *= logit_scale

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

@property
def sampler(self):
return self.language_model.sampler

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
@ -239,32 +242,36 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object) -> SamplerOutput:

parsed_image_input = self._parse_and_validate_image_input(**kwargs)

if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (self.config.hidden_size**
-0.5)

inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)

**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
parsed_image_input = self._parse_and_validate_image_input(**kwargs)

if parsed_image_input is not None:
vision_embeddings = self._process_image_input(
parsed_image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (
self.config.hidden_size**-0.5)

inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)

input_ids = None
else:
inputs_embeds = None

hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
intermediate_tensors,
inputs_embeds=inputs_embeds)

return hidden_states
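
Beyond the PP plumbing, the PaliGemma hunk also stops duplicating the language model's output head: instead of building its own LogitsProcessor and Sampler, the wrapper folds logit_scale into the inner GemmaForCausalLM's processor and re-exports its sampler as a property. A minimal shape of that delegation; the classes below are stand-ins, not vLLM types.

# Stand-in classes showing the delegation pattern adopted by PaliGemma.
class _LogitsProcessor:
    def __init__(self) -> None:
        self.scale = 1.0

class InnerLM:
    def __init__(self) -> None:
        self.logits_processor = _LogitsProcessor()
        self.sampler = object()      # whatever the inner model constructed

class MultiModalWrapper:
    def __init__(self, inner: InnerLM, logit_scale: float) -> None:
        self.language_model = inner
        # Fold the multimodal logit scale into the existing processor
        # instead of instantiating a second one.
        self.language_model.logits_processor.scale *= logit_scale

    @property
    def sampler(self):               # re-export rather than re-create
        return self.language_model.sampler
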
@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -28,14 +28,13 @@ from transformers import PersimmonConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -44,6 +43,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class PersimmonMLP(nn.Module):

@ -211,20 +214,23 @@ class PersimmonModel(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.vocab_size = config.vocab_size

self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PersimmonDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PersimmonDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def forward(
self,
@ -232,24 +238,31 @@ class PersimmonModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
hidden_states = self.layers[i](
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states)
return hidden_states


class PersimmonForCausalLM(nn.Module):
class PersimmonForCausalLM(nn.Module, SupportsPP):

def __init__(self,
config: PersimmonConfig,
@ -266,6 +279,8 @@ class PersimmonForCausalLM(nn.Module):
bias=False)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -281,6 +296,7 @@ class PersimmonForCausalLM(nn.Module):
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
@ -312,6 +328,8 @@ class PersimmonForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]

if "query_key_value" in name:
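
Each converted class also adds SupportsPP to its bases. Read as a contract, the interface amounts to roughly the Protocol below: accept intermediate_tensors in forward(), possibly returning an IntermediateTensors bundle instead of hidden states, and expose make_empty_intermediate_tensors. This is an interpretive sketch, not the definition in .interfaces, which may additionally carry class-level flags; plain dicts again stand in for IntermediateTensors.

# Interpretive Protocol sketch of what SupportsPP asks of a model.
from typing import Dict, Optional, Protocol, Union

import torch

class SupportsPPSketch(Protocol):
    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> Dict[str, torch.Tensor]:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        intermediate_tensors: Optional[Dict[str, torch.Tensor]],
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        ...
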
@ -35,7 +35,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -43,14 +43,13 @@ from transformers import PhiConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -59,7 +58,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class PhiAttention(nn.Module):
@ -196,18 +197,22 @@ class PhiModel(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiLayer(config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def forward(
self,
@ -215,23 +220,31 @@ class PhiModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(self.config.num_hidden_layers):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})

hidden_states = self.final_layernorm(hidden_states)

return hidden_states


class PhiForCausalLM(nn.Module, SupportsLoRA):
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -274,6 +287,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@ -282,9 +297,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)

return hidden_states

@ -325,6 +340,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -335,6 +352,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
continue
# pylint: disable=E1136

if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
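
The Union[torch.Tensor, IntermediateTensors] return type added throughout only makes sense with the runner's loop in mind: non-last ranks hand their IntermediateTensors to a send op, and later ranks receive one before calling forward. A schematic of that loop, with recv and send as stand-ins for vLLM's pipeline-group communication ops; once merged, these models can be served across ranks with the documented --pipeline-parallel-size engine argument.

# Schematic per-rank driver loop; `recv` and `send` are stand-ins for the
# pipeline group's tensor-dict communication ops.
from typing import Callable, Dict, Optional

import torch

def run_rank(model: Callable, input_ids: torch.Tensor, is_first: bool,
             is_last: bool, recv: Callable, send: Callable
             ) -> Optional[torch.Tensor]:
    intermediate: Optional[Dict[str, torch.Tensor]] = None
    if not is_first:
        intermediate = recv()          # activations from the previous rank
    out = model(input_ids, intermediate_tensors=intermediate)
    if not is_last:
        send(out)                      # IntermediateTensors-style dict
        return None
    return out                         # final hidden states -> logits/sampling
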
@ -1,5 +1,5 @@
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -7,14 +7,13 @@ from transformers.configuration_utils import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -23,6 +22,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


def load_column_parallel_weight(param: torch.nn.Parameter,
loaded_weight: torch.Tensor):
@ -301,20 +304,25 @@ class Phi3SmallModel(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.mup_embedding_multiplier = config.mup_embedding_multiplier
self.layers = nn.ModuleList([
Phi3SmallDecoderLayer(config, layer_idx, cache_config,
quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config, quant_config),
prefix=f"{prefix}.layers")

self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def get_input_embeddings(self):
return self.embed_tokens
@ -327,30 +335,37 @@ class Phi3SmallModel(nn.Module):
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata = None,
):
hidden_states = self.embed_tokens(input_ids)
if (self.mup_embedding_multiplier is not None
and self.mup_embedding_multiplier > 0.0):
hidden_states = hidden_states * self.mup_embedding_multiplier
for i in range(len(self.layers)):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
if (self.mup_embedding_multiplier is not None
and self.mup_embedding_multiplier > 0.0):
hidden_states = hidden_states * self.mup_embedding_multiplier
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states)
return hidden_states


class Phi3SmallForCausalLM(nn.Module):
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
_tied_weights_keys = ["lm_head.weight"]

def __init__(
self,
config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
@ -372,6 +387,8 @@ class Phi3SmallForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# tokens in tiktoken but not used
if hasattr(config, 'dummy_token_indices'):
@ -419,12 +436,13 @@ class Phi3SmallForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
output_hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
output_hidden_states = output_hidden_states
return output_hidden_states
@ -447,6 +465,8 @@ class Phi3SmallForCausalLM(nn.Module):
continue
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
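
Phi3SmallModel (like OlmoeModel earlier) needs the global layer index inside its layer constructor, and make_layers only supplies the dotted parameter prefix, hence the int(prefix.split('.')[-1]) idiom in the hunk above. A tiny, runnable illustration with a hypothetical prefix:

# The prefix handed to the layer factory ends in the global layer index.
prefix = "model.layers.17"   # hypothetical prefix for layer 17
layer_idx = int(prefix.split(".")[-1])
assert layer_idx == 17
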
@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import re
from functools import lru_cache
from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)

@ -29,13 +29,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
@ -43,8 +41,9 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .utils import flatten_bn, merge_multimodal_embeddings
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings)

logger = init_logger(__name__)

@ -295,6 +294,37 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
dim=2).reshape(num_images, -1, hid_dim)
return image_features_hd_newline

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)

# load vision encoder
self.img_processor.load_weights(weights_group["img_processor"])

# load glb_GN
for name, loaded_weight in weights_group["glb_GN"]:
assert name == ""
param = self.glb_GN
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load sub_GN
for name, loaded_weight in weights_group["sub_GN"]:
assert name == ""
param = self.sub_GN
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load mlp projector
mlp_params_dict = dict(self.img_projection.named_parameters())
for name, loaded_weight in weights_group["img_projection"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
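
The new Phi3HDImageEmbedding.load_weights above leans on group_weights_with_prefix to route each sub-module its own slice of the checkpoint stream. A plausible sketch of that helper, consistent with the assert name == "" seen for the bare glb_GN and sub_GN parameters; the real implementation in .utils may buffer differently.

# Sketch of group_weights_with_prefix: bucket checkpoint entries by their
# first dotted component and strip that prefix from the remaining name.
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch

def group_weights_with_prefix_sketch(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
    groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        # A bare parameter like "glb_GN" yields rest == "", matching the
        # `assert name == ""` in the loader above.
        groups[prefix].append((rest, tensor))
    return groups
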
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
@ -508,7 +538,7 @@ def input_processor_for_phi3v(ctx: InputContext,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self,
config: PretrainedConfig,
@ -521,17 +551,21 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID

self.model = LlamaModel(config, cache_config, quant_config)

# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

self.language_model = LlamaForCausalLM(config, cache_config,
quant_config)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler

return Sampler()

def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
@ -631,24 +665,29 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

hidden_states = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None

hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)

return hidden_states

@ -657,66 +696,38 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
hf_to_vllm_mapping = {
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
}

# TODO(ChristopherCho): This is a temporary fix to load
# the vision weights with CLIPVisionModel.load_weights()
vision_weights = []
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Skip loading the img_processor weights since they are
# loaded separately.
if "vision_embed_tokens.img_processor" in name:
vision_weights.append((name, loaded_weight))
continue
def hf_to_vllm_name(key: str) -> str:
for hf_name, vllm_name in hf_to_vllm_mapping.items():
if key.startswith(hf_name):
return key.replace(hf_name, vllm_name, 1)

for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
return key

param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name in params_dict:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}
|
||||
|
||||
# We use regex to extract the sub-module name
|
||||
# from "model.vision_embed_tokens.img_processor.*"
|
||||
vision_weights = [
|
||||
(re.search(r"vision_embed_tokens\.img_processor\.(.*)",
|
||||
n).group(1), w) for n, w in vision_weights
|
||||
]
|
||||
self.vision_embed_tokens.img_processor.load_weights(vision_weights)
|
||||
# prepare weight iterators for components
|
||||
weights_group = group_weights_with_prefix(vllm_weights.items())
|
||||
|
||||
# load vision embeddings and encoder
|
||||
self.vision_embed_tokens.load_weights(
|
||||
weights_group["vision_embed_tokens"])
|
||||
|
||||
# load llm backbone
|
||||
self.language_model.load_weights(weights_group["language_model"])
|
||||
|
||||
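The rewritten load_weights above routes each component's slice of the checkpoint through group_weights_with_prefix. Below is a minimal standalone sketch of that grouping idea, assuming a hypothetical list-based helper named group_by_prefix (the real helper in utils.py returns lazy iterators rather than lists):

from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch


def group_by_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
    """Bucket "prefix.rest" weight names by their first component."""
    groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        groups[prefix].append((rest, tensor))
    return groups


weights = [
    ("img_processor.encoder.weight", torch.zeros(4, 4)),
    ("glb_GN", torch.zeros(1, 1, 4)),
]
groups = group_by_prefix(weights)
# A top-level tensor keeps an empty remainder, which matches the
# `assert name == ""` checks in the loader above.
assert [n for n, _ in groups["glb_GN"]] == [""]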
@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only PhiMoE model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -29,7 +29,7 @@ from transformers.configuration_utils import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
@ -46,7 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class PhiMoEConfig(PretrainedConfig):
@ -435,6 +437,7 @@ class PhiMoEModel(nn.Module):
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
@ -448,33 +451,56 @@ class PhiMoEModel(nn.Module):
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiMoEDecoderLayer(config, cache_config,
                                              quant_config),
            prefix=f"{prefix}.layers")
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.rms_norm_eps,
                                 elementwise_affine=True)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states = self.norm(hidden_states)
        return hidden_states


class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
@ -537,6 +563,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
                                                config.vocab_size)
        self.sampler = Sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
@ -544,9 +573,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -589,6 +618,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@ -599,6 +631,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
@ -613,6 +648,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
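The PhiMoE forward above follows the pipeline-parallel skeleton this commit rolls out across models: the first rank embeds tokens, later ranks receive hidden states from the previous rank, and every rank but the last hands its output onward. A schematic, framework-free sketch of that contract (Rank and the plain-dict handoff are stand-ins for vLLM's get_pp_group()/IntermediateTensors, not the real API):

from dataclasses import dataclass
from typing import Dict, Optional, Union

import torch


@dataclass
class Rank:
    index: int
    world: int

    @property
    def is_first(self) -> bool:
        return self.index == 0

    @property
    def is_last(self) -> bool:
        return self.index == self.world - 1


def pp_forward(
    rank: Rank,
    layers,                       # this rank's slice of decoder layers
    embed,                        # embedding fn, used on the first rank only
    norm,                         # final norm, used on the last rank only
    input_ids: Optional[torch.Tensor],
    recv: Optional[Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    if rank.is_first:
        hidden = embed(input_ids)
    else:
        assert recv is not None          # handed over from the previous rank
        hidden = recv["hidden_states"]
    for layer in layers:
        hidden = layer(hidden)
    if not rank.is_last:
        return {"hidden_states": hidden}  # shipped to the next rank
    return norm(hidden)


rank = Rank(index=0, world=2)
out = pp_forward(rank, [lambda h: h * 2], lambda ids: ids.float(),
                 lambda h: h, torch.tensor([1, 2]), None)
assert isinstance(out, dict)  # first of two ranks hands off, never norms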
@ -1,4 +1,5 @@
from dataclasses import dataclass, fields
from functools import cached_property
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union

@ -16,7 +17,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -25,7 +26,7 @@ from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData

from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model


@ -126,7 +127,8 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):

    def __init__(self,
                 config: PretrainedConfig,
@ -155,6 +157,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
@ -163,32 +175,36 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral.

        TODO

        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_args.image_token_id)

        if intermediate_tensors is not None:
            input_ids = None
        else:
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.vision_args.image_token_id)

                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

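Pixtral, Phi-3V, and Ultravox all funnel their modality features through merge_multimodal_embeddings, which overwrites the placeholder-token positions in the text embeddings with the encoder's output. A minimal sketch of that scatter, assuming the hypothetical names merge_embeddings and placeholder id 99 (the real helper lives in models/utils.py and does additional validation):

import torch


def merge_embeddings(
    input_ids: torch.Tensor,       # (num_tokens,)
    inputs_embeds: torch.Tensor,   # (num_tokens, hidden)
    mm_embeds: torch.Tensor,       # (num_placeholders, hidden)
    placeholder_id: int,
) -> torch.Tensor:
    """Overwrite placeholder positions with multimodal features."""
    mask = input_ids == placeholder_id
    assert int(mask.sum()) == mm_embeds.shape[0], "placeholder count mismatch"
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[mask] = mm_embeds
    return inputs_embeds


ids = torch.tensor([1, 99, 99, 2])   # 99 = hypothetical image token id
text = torch.zeros(4, 8)
vision = torch.ones(2, 8)
merged = merge_embeddings(ids, text, vision, placeholder_id=99)
assert torch.equal(merged[1], torch.ones(8))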
@ -31,15 +31,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@ -47,7 +45,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of

from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

logger = init_logger(__name__)

@ -568,6 +568,9 @@ class QWenModel(nn.Module):
            lambda prefix: QWenBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        self.visual = VisionTransformer(**config.visual,
                                        quant_config=quant_config) if hasattr(
                                            config, "visual") else None
@ -580,7 +583,7 @@ class QWenModel(nn.Module):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        pixel_values: Optional[QwenImageInputs],
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        img_pos = None
        # If pixel / visual embeddings are provided, this is a visual model
        if pixel_values is not None and self.visual is not None:
@ -860,7 +863,7 @@ def dummy_data_for_qwen(
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(nn.Module, SupportsMultiModal):
class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(
        self,
@ -881,6 +884,8 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def _get_image_input_type(
        self,
@ -912,33 +917,26 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
        )
        return None

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        pixel_values = self._get_image_input_type(pixel_values)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        pixel_values: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            input_ids = None
            pixel_values = None
        else:
            pixel_values = self._get_image_input_type(pixel_values)

        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         pixel_values)
        return hidden_states

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
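The deletion above is the recurring shape of this commit: every model used to hand-write make_empty_intermediate_tensors, and each now takes it from the shared make_empty_intermediate_tensors_factory in utils.py. A sketch of what such a factory produces, returning plain dicts where the real helper returns vLLM's IntermediateTensors:

from typing import Dict, List

import torch


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """Build a zero-filled inter-rank buffer constructor (sketch only)."""

    def make_empty(
        batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> Dict[str, torch.Tensor]:
        return {
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        }

    return make_empty


builder = make_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], 16)
empty = builder(2, torch.float32, torch.device("cpu"))
assert empty["residual"].shape == (2, 16)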
@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -48,8 +47,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class Qwen2MLP(nn.Module):
@ -253,6 +253,9 @@ class Qwen2Model(nn.Module):
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
@ -269,7 +272,7 @@ class Qwen2Model(nn.Module):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
@ -298,7 +301,7 @@ class Qwen2Model(nn.Module):
        return hidden_states


class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@ -357,6 +360,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -365,7 +370,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states
@ -379,20 +384,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: torch.Tensor,
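Throughout these files, make_layers returns (start_layer, end_layer, layers), and the forwards index kv_caches[i - self.start_layer] because each rank's cache list covers only its own slice. A sketch of one contiguous, mostly-even partitioning policy (for illustration only; vLLM's actual make_layers may split differently, e.g. via configurable partitions):

from typing import Tuple


def partition_layers(num_layers: int, rank: int, world: int) -> Tuple[int, int]:
    """Contiguous [start, end) slice of decoder layers owned by one PP rank,
    spreading any remainder over the earliest ranks."""
    per_rank = num_layers // world
    remainder = num_layers % world
    start = rank * per_rank + min(rank, remainder)
    end = start + per_rank + (1 if rank < remainder else 0)
    return start, end


# 32 layers over 4 ranks -> 8 each; rank 1 owns layers [8, 16).
assert partition_layers(32, 1, 4) == (8, 16)
# 30 layers over 4 ranks -> the first two ranks take the extra layers.
assert partition_layers(30, 3, 4) == (23, 30)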
@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@ -42,8 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -53,7 +52,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once

from .utils import is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class Qwen2MoeMLP(nn.Module):
@ -338,6 +339,9 @@ class Qwen2MoeModel(nn.Module):
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
@ -346,7 +350,7 @@ class Qwen2MoeModel(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
@ -368,7 +372,7 @@ class Qwen2MoeModel(nn.Module):
        return hidden_states


class Qwen2MoeForCausalLM(nn.Module):
class Qwen2MoeForCausalLM(nn.Module, SupportsPP):

    fall_back_to_pt_during_load = False

@ -389,6 +393,8 @@ class Qwen2MoeForCausalLM(nn.Module):
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -397,7 +403,7 @@ class Qwen2MoeForCausalLM(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states
@ -411,20 +417,6 @@ class Qwen2MoeForCausalLM(nn.Module):
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
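The load_weights changes in these files all insert the same guard: is_pp_missing_parameter skips checkpoint tensors whose destination layer lives on another pipeline rank (where the module is replaced by a PPMissingLayer placeholder). A simplified sketch of that check, assuming the check walks the module tree; the real helper in utils.py precomputes the missing-layer name prefixes:

import torch
from torch import nn


class PPMissingLayer(nn.Module):
    """Placeholder for a layer owned by another pipeline rank (sketch:
    holds no parameters, so nothing can be loaded into it)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def is_pp_missing_parameter(name: str, model: nn.Module) -> bool:
    """True when `name` points inside a PPMissingLayer placeholder."""
    for module_name, module in model.named_modules():
        if isinstance(module, PPMissingLayer) and name.startswith(
                module_name + "."):
            return True
    return False


model = nn.Sequential(nn.Linear(4, 4), PPMissingLayer())
assert is_pp_missing_parameter("1.weight", model)      # lives elsewhere
assert not is_pp_missing_parameter("0.weight", model)  # local, load it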
@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalInputs)
@ -68,6 +67,7 @@ from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory)

@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
    "video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):

    def __init__(self,
                 config: Qwen2VLConfig,
@ -1027,7 +1028,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Qwen2-VL.

        Args:
@ -1047,41 +1048,43 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
                `None` if no videos are passed.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        video_input = self._parse_and_validate_video_input(**kwargs)

        if (image_input is None
                and video_input is None) or not get_pp_group().is_first_rank:
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            if getattr(self.config, "rope_scaling", {}).get("type",
                                                            None) == "mrope":
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}")
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)

            inputs_embeds = self.model.embed_tokens(input_ids)
            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                rope_scaling = getattr(self.config, "rope_scaling", {})
                if rope_scaling.get("type", None) == "mrope":
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")

        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = self._merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )
                inputs_embeds = self.model.embed_tokens(input_ids)

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = self._merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
                if image_input is not None:
                    image_embeds = self._process_image_input(image_input)
                    inputs_embeds = self._merge_multimodal_embeddings(
                        input_ids,
                        inputs_embeds,
                        image_embeds,
                        placeholder_token_id=self.config.image_token_id,
                    )

        input_ids = None
                if video_input is not None:
                    video_embeds = self._process_video_input(video_input)
                    inputs_embeds = self._merge_multimodal_embeddings(
                        input_ids,
                        inputs_embeds,
                        video_embeds,
                        placeholder_token_id=self.config.video_token_id,
                    )

                input_ids = None

        hidden_states = self.model(
            input_ids=input_ids,
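The (3, seq_len) assertion above reflects Qwen2-VL's multimodal-section rotary embedding, which carries three position streams per token rather than one (temporal, height, and width in the model's design; the stream semantics are stated here as background, not taken from this diff). A shape-only illustration of what the forward expects:

import torch

seq_len = 5
positions = torch.stack([
    torch.arange(seq_len),   # stream 0
    torch.arange(seq_len),   # stream 1
    torch.arange(seq_len),   # stream 2
])
# This is exactly the condition the forward asserts before proceeding.
assert positions.ndim == 2 and positions.size(0) == 3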
@ -246,7 +246,7 @@ class SiglipParallelAttention(nn.Module):

    def __init__(
        self,
        config,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
@ -312,7 +312,7 @@ class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
@ -37,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
@ -47,14 +47,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.utils import (PPMissingLayer,
                                              is_pp_missing_parameter,
                                              make_layers)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class SolarMLP(nn.Module):

@ -98,7 +98,7 @@ class SolarAttention(nn.Module):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
@ -187,7 +187,7 @@ class SolarDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@ -267,7 +267,7 @@ class SolarModel(nn.Module):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
@ -304,6 +304,10 @@ class SolarModel(nn.Module):
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

@ -368,7 +372,7 @@ class SolarModel(nn.Module):
        return hidden_states


class SolarForCausalLM(nn.Module, SupportsLoRA):
class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@ -406,7 +410,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA):

    def __init__(
        self,
        config,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
@ -448,6 +452,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA):
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
@ -474,24 +481,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA):
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros(
                (batch_size, self.config.hidden_size),
                dtype=dtype,
                device=device,
            ),
            "residual":
            torch.zeros(
                (batch_size, self.config.hidden_size),
                dtype=dtype,
                device=device,
            ),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
@ -19,7 +19,7 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -27,14 +27,13 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -43,6 +42,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class StablelmMLP(nn.Module):

@ -194,19 +197,25 @@ class StableLMEpochModel(nn.Module):
    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = '') -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            StablelmDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: StablelmDecoderLayer(config, cache_config,
                                                quant_config),
            prefix=f"{prefix}.layers",
        )
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
@ -214,21 +223,28 @@ class StableLMEpochModel(nn.Module):
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


class StablelmForCausalLM(nn.Module):
class StablelmForCausalLM(nn.Module, SupportsPP):

    def __init__(
        self,
@ -247,6 +263,8 @@ class StablelmForCausalLM(nn.Module):
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -255,9 +273,9 @@ class StablelmForCausalLM(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
@ -302,6 +320,8 @@ class StablelmForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@ -310,6 +330,8 @@ class StablelmForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
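The getattr(param, "weight_loader", default_weight_loader) idiom that recurs in every load_weights above is worth a standalone illustration: tensor-parallel layers attach a custom weight_loader attribute that shards the checkpoint tensor, and everything else falls back to a plain copy. A minimal sketch of the fallback path only (the sharding loaders are vLLM-internal and not reproduced here):

import torch
from torch import nn


def default_weight_loader(param: nn.Parameter, loaded: torch.Tensor) -> None:
    """Plain copy, used when a parameter defines no custom loader."""
    assert param.data.shape == loaded.shape
    param.data.copy_(loaded)


param = nn.Parameter(torch.empty(2, 2))
# No custom loader attached, so the plain copy is selected.
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, torch.ones(2, 2))
assert torch.equal(param.data, torch.ones(2, 2))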
@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@ -26,14 +26,13 @@ from transformers import Starcoder2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -42,6 +41,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class Starcoder2Attention(nn.Module):

@ -195,7 +198,8 @@ class Starcoder2Model(nn.Module):
    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
@ -204,13 +208,16 @@ class Starcoder2Model(nn.Module):
        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            Starcoder2DecoderLayer(config,
                                   cache_config,
                                   quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Starcoder2DecoderLayer(
                config, cache_config, quant_config=quant_config),
            prefix=f"{prefix}.layers",
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
@ -218,17 +225,25 @@ class Starcoder2Model(nn.Module):
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states, kv_caches[i],
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


class Starcoder2ForCausalLM(nn.Module):
class Starcoder2ForCausalLM(nn.Module, SupportsPP):

    def __init__(self,
                 config: Starcoder2Config,
@ -255,6 +270,8 @@ class Starcoder2ForCausalLM(nn.Module):
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
@ -263,9 +280,9 @@ class Starcoder2ForCausalLM(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
@ -302,6 +319,8 @@ class Starcoder2ForCausalLM(nn.Module):
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@ -309,6 +328,8 @@ class Starcoder2ForCausalLM(nn.Module):
            else:
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
@ -3,7 +3,7 @@

import math
from array import array
from functools import lru_cache
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union, cast)

@ -22,12 +22,10 @@ from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (flatten_bn,
                                              group_weights_with_prefix,
                                              init_vllm_registered_model,
@ -37,9 +35,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

from .interfaces import SupportsMultiModal, SupportsPP

_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25

@ -323,7 +324,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
    "audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal):
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(self,
                 config: UltravoxConfig,
@ -353,6 +354,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                revision=None,
                prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        audio_input = input_features.to(self.audio_tower.dtype)
@ -425,7 +436,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[torch.Tensor],
                **kwargs) -> SamplerOutput:
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
@ -438,18 +449,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
        Args:
            audio_features: A batch of audio inputs [B, N, 80, M].
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is not None:
            audio_embeddings = self._process_audio_input(audio_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, audio_embeddings,
                _AUDIO_PLACEHOLDER_TOKEN)
        if intermediate_tensors is not None:
            input_ids = None
        else:
            inputs_embeds = None
        else:
            audio_input = self._parse_and_validate_audio_input(**kwargs)
            if audio_input is not None:
                audio_embeddings = self._process_audio_input(audio_input)
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, audio_embeddings,
                    _AUDIO_PLACEHOLDER_TOKEN)
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
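Phi-3V, Pixtral, and Ultravox all gained the same control flow in forward: only the first pipeline rank parses multimodal inputs and builds inputs_embeds, while later ranks short-circuit on intermediate_tensors. A condensed sketch of that shared shape, using a hypothetical free function (the models inline this logic rather than calling a helper):

from typing import Dict, Optional, Tuple

import torch


def prepare_pp_inputs(
    input_ids: Optional[torch.Tensor],
    intermediate_tensors: Optional[Dict[str, torch.Tensor]],
    mm_input: Optional[torch.Tensor],
    embed_and_merge,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Returns (input_ids, inputs_embeds) for the language model call."""
    if intermediate_tensors is not None:
        # Non-first PP rank: hidden states arrive over the wire, so both
        # token ids and embeddings are dropped.
        return None, None
    if mm_input is None:
        # Text-only request: let the LM embed input_ids itself.
        return input_ids, None
    # Multimodal request on the first rank: merge features, drop the ids.
    return None, embed_and_merge(input_ids, mm_input)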
@ -24,7 +24,7 @@ class WeightsGroup(UserDict):
    when attempting to access a weight component that does not exist.
    """

    def __getitem__(self, key: str) -> int:
    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],


def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
    """
    Helper function to group weights with prefix
    """
@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,

class LayerFn(Protocol):

    def __call__(
        self,
        prefix="",
    ) -> torch.nn.Module:
    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
@ -342,8 +340,14 @@ class LLMWrapper(nn.Module):
        self.model_name = name
        setattr(self, name, llm)

    def forward(self, *args, **kwargs) -> Any:
        return getattr(self, self.model_name)(*args, **kwargs)
    def __getattr__(self, key: str):
        llm = super().__getattr__(self.model_name)
        if key == self.model_name:
            return llm

    def embed_tokens(self, *args, **kwargs) -> Any:
        return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
        return getattr(llm, key)

    # We need to explicitly override this
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        llm = super().__getattr__(self.model_name)
        return llm(*args, **kwargs)
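The LLMWrapper change above swaps per-method passthroughs (forward, embed_tokens) for generic attribute delegation, so any attribute of the wrapped module resolves transparently; __call__ still needs an explicit override because Python looks it up on the type, not the instance. A self-contained sketch of the same pattern, under the assumption that the wrapped object is a plain nn.Module:

from typing import Any

from torch import nn


class Wrapper(nn.Module):
    """Delegation sketch mirroring the LLMWrapper change above."""

    def __init__(self, llm: nn.Module, name: str) -> None:
        super().__init__()
        self.model_name = name
        setattr(self, name, llm)   # registers llm as a submodule

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal lookup fails; nn.Module.__getattr__
        # resolves registered submodules such as the wrapped llm.
        llm = super().__getattr__(self.model_name)
        if key == self.model_name:
            return llm
        return getattr(llm, key)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Explicit override: __call__ is looked up on the type.
        llm = super().__getattr__(self.model_name)
        return llm(*args, **kwargs)


inner = nn.Linear(4, 4)
w = Wrapper(inner, "model")
assert w.out_features == 4   # delegated attribute lookup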
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Xverse model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -28,15 +28,14 @@ from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -45,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)


 class XverseMLP(nn.Module):
@@ -227,6 +228,7 @@ class XverseModel(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -240,11 +242,16 @@ class XverseModel(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            XverseDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: XverseDecoderLayer(config, cache_config,
+                                              quant_config),
+            prefix=f"{prefix}.layers",
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))

     def forward(
         self,
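make_layers is the core of the PP refactor: each pipeline rank builds real modules only for its contiguous [start_layer, end_layer) slice. A simplified sketch assuming an even split, with nn.Identity standing in for vLLM's placeholder layers:

from typing import Callable, Tuple

import torch.nn as nn

def toy_make_layers(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    rank: int,
    world_size: int,
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    # Even split, e.g. 32 layers over 4 ranks -> 8 layers per rank.
    per_rank = num_layers // world_size
    start, end = rank * per_rank, (rank + 1) * per_rank
    layers = nn.ModuleList([
        layer_fn(f"{prefix}.{i}") if start <= i < end else nn.Identity()
        for i in range(num_layers)
    ])
    return start, end, layers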
@@ -252,23 +259,33 @@ class XverseModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


-class XverseForCausalLM(nn.Module, SupportsLoRA):
+class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
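The forward rewrite above is the canonical PP split: the first rank embeds tokens, every rank runs only its own layer slice (indexing kv_caches locally, since a rank holds caches only for its layers), and all ranks except the last ship hidden_states and residual downstream instead of applying the final norm. A condensed sketch of the rank logic, with plain dicts and callables standing in for vLLM's types:

from typing import Dict, List, Optional, Union

import torch

def toy_pp_forward(
    rank: int,
    world_size: int,
    embeddings: torch.Tensor,
    layers: List,  # each: (hidden, residual) -> (hidden, residual)
    start: int,
    end: int,
    inbound: Optional[Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    if rank == 0:
        hidden, residual = embeddings, None
    else:
        # Resume from the tensors shipped by the previous stage.
        hidden = inbound["hidden_states"]
        residual = inbound["residual"]
    for i in range(start, end):
        # KV caches are indexed locally: this rank owns [0, end - start).
        hidden, residual = layers[i](hidden, residual)
    if rank != world_size - 1:
        # Hand off to the next stage instead of applying the final norm.
        return {"hidden_states": hidden, "residual": residual}
    return hidden if residual is None else hidden + residual  # norm stand-in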
@@ -317,6 +333,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
         self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -325,9 +343,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -368,6 +386,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -376,6 +396,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
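With layers partitioned, a rank's params_dict contains only its own slice, so load_weights must skip checkpoint entries that belong to other ranks, which is what the two is_pp_missing_parameter guards above do. A rough sketch of the check (simplified; the real helper inspects the model's missing-layer placeholders rather than parsing names):

import re

def toy_is_pp_missing_parameter(name: str, start: int, end: int) -> bool:
    """True if `name` is a decoder-layer weight this rank does not own."""
    match = re.match(r"model\.layers\.(\d+)\.", name)
    if match is None:
        return False  # embeddings, final norm, lm_head, etc.
    layer_idx = int(match.group(1))
    return not (start <= layer_idx < end)

# A rank owning layers [8, 16) skips layer 3's weights:
assert toy_is_pp_missing_parameter("model.layers.3.qkv_proj.weight", 8, 16)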