diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 0b775851c057..55530d0da8d7 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -389,7 +389,8 @@ steps:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
+ - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
+ - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd]
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index f1803b39c883..afe0b53077a7 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
+ if not hasattr(config, "hidden_size"):
+ # Support for llama4
+ config = config.text_config
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 8b568de7c81c..2fb969ea85f1 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi
To check if the modeling backend is Transformers, you can simply do this:
-```python
+```python
from vllm import LLM
llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
@@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM or Transformers, you can sti
Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
-```python
+```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
@@ -850,6 +850,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
+- * `Llama4ForConditionalGeneration`
+ * Llama-4-17B-Omni-Instruct
+ * T + I+
+ * `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
+ *
+ * ✅︎
+ * ✅︎
- * `LlavaForConditionalGeneration`
* LLaVA-1.5
* T + IE+
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 840892ea0701..f33efbab955e 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=4096,
- max_num_seqs=5,
+ max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index c1115708505a..61d53dda1c47 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)
+def run_llama4(questions: list[str], modality: str):
+ assert modality == "image"
+
+ model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ max_num_seqs=4,
+ tensor_parallel_size=8,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ gpu_memory_utilization=0.4,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ messages = [[{
+ "role":
+ "user",
+ "content": [{
+ "type": "image"
+ }, {
+ "type": "text",
+ "text": f"{question}"
+ }]
+ }] for question in questions]
+ prompts = tokenizer.apply_chat_template(messages,
+ add_generation_prompt=True,
+ tokenize=False)
+ stop_token_ids = None
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=stop_token_ids,
+ )
+
+
# Molmo
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -907,6 +943,7 @@ model_example_map = {
"minicpmv": run_minicpmv,
"mistral3": run_mistral3,
"mllama": run_mllama,
+ "llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
"paligemma": run_paligemma,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 39951e5e89c4..e03ebe485eaa 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
)
+def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
+ model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ max_num_seqs=4,
+ tensor_parallel_size=8,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [{
+ "role":
+ "user",
+ "content": [
+ *placeholders,
+ {
+ "type": "text",
+ "text": question
+ },
+ ],
+ }]
+
+ processor = AutoProcessor.from_pretrained(model_name)
+
+ prompt = processor.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ image_data=[fetch_image(url) for url in image_urls],
+ )
+
+
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@@ -567,6 +604,7 @@ model_example_map = {
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
+ "llama4": load_llama4,
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
diff --git a/requirements/common.txt b/requirements/common.txt
index 7365a5b46a30..24a1e6d67ac2 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -6,7 +6,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
-transformers >= 4.50.3
+transformers >= 4.51.0
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
diff --git a/requirements/test.in b/requirements/test.in
index 364747e9c08f..ac7f451e96a8 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
-transformers==4.50.3
+transformers==4.51.0
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3
diff --git a/requirements/test.txt b/requirements/test.txt
index 236b8be32805..39d6ed1acff0 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -645,7 +645,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
-transformers==4.50.3
+transformers==4.51.0
# via
# -r requirements/test.in
# genai-perf
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index 83ece5d22bfb..a843e41aa26e 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -12,6 +12,7 @@ from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
+from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
@@ -55,7 +56,10 @@ def server(request, audio_assets):
for key, value in request.param.items()
]
- with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+ with RemoteOpenAIServer(MODEL_NAME,
+ args,
+ env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
+ "30"}) as remote_server:
yield remote_server
@@ -106,6 +110,10 @@ def run_test(
**kwargs,
):
"""Inference result should be the same between hf and vllm."""
+ model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+ model_info.check_available_online(on_fail="skip")
+ model_info.check_transformers_version(on_fail="skip")
+
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
@@ -156,6 +164,10 @@ def run_multi_audio_test(
num_logprobs: int,
**kwargs,
):
+ model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+ model_info.check_available_online(on_fail="skip")
+ model_info.check_transformers_version(on_fail="skip")
+
with vllm_runner(model,
dtype=dtype,
enforce_eager=True,
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 3b34f012f626..9d9e8278af4f 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -160,17 +160,32 @@ VLM_TEST_SETTINGS = {
),
"aya_vision": VLMTestInfo(
models=["CohereForAI/aya-vision-8b"],
- test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "What's the content in the center of the image?", # noqa: E501
"cherry_blossom": "What is the season?", # noqa: E501
}),
multi_image_prompt="Describe the two images in detail.", # noqa: E501
- max_model_len=8192,
+ max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
- vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
+ vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}},
+ ),
+ "aya_vision-multi_image": VLMTestInfo(
+ models=["CohereForAI/aya-vision-8b"],
+ test_type=(VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501
+ single_image_prompts=IMAGE_ASSETS.prompts({
+ "stop_sign": "What's the content in the center of the image?", # noqa: E501
+ "cherry_blossom": "What is the season?", # noqa: E501
+ }),
+ multi_image_prompt="Describe the two images in detail.", # noqa: E501
+ max_model_len=4096,
+ max_num_seqs=2,
+ auto_cls=AutoModelForImageTextToText,
+ vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}},
+ marks=[large_gpu_mark(min_gb=32)],
),
"blip2": VLMTestInfo(
# TODO: Change back to 2.7b once head_dim = 80 is supported
@@ -303,6 +318,22 @@ VLM_TEST_SETTINGS = {
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
+ "llama4": VLMTestInfo(
+ models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
+ prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
+ img_idx_to_prompt=lambda _: "<|image|>",
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ distributed_executor_backend="mp",
+ image_size_factors=[(.25, 0.5, 1.0)],
+ hf_model_kwargs={"device_map": "auto"},
+ max_model_len=8192,
+ max_num_seqs=4,
+ dtype="bfloat16",
+ auto_cls=AutoModelForImageTextToText,
+ tensor_parallel_size=8,
+ vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
+ marks=multi_gpu_marks(num_gpus=8),
+ ),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS),
diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py
index 53b183b2735e..237d499d8f6a 100644
--- a/tests/models/decoder_only/vision_language/test_phi3v.py
+++ b/tests/models/decoder_only/vision_language/test_phi3v.py
@@ -5,7 +5,9 @@ import re
from typing import Optional
import pytest
+from packaging.version import Version
from transformers import AutoTokenizer
+from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.multimodal.image import rescale_image_size
from vllm.platforms import current_platform
@@ -81,6 +83,13 @@ def run_test(
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401
+ # Once the model repo is updated to 4.49, we should be able to run the
+ # test in `test_models.py` without the above workaround
+ if Version(TRANSFORMERS_VERSION) >= Version("4.49"):
+ pytest.skip(f"`transformers=={TRANSFORMERS_VERSION}` installed, "
+ "but `transformers<=4.49` is required to run this model. "
+ "Reason: Cannot run HF implementation")
+
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py
index ee619d8d80c4..2f14a8ea321f 100644
--- a/tests/models/decoder_only/vision_language/test_pixtral.py
+++ b/tests/models/decoder_only/vision_language/test_pixtral.py
@@ -176,6 +176,8 @@ def test_chat(
model,
dtype=dtype,
tokenizer_mode="mistral",
+ load_format="mistral",
+ config_format="mistral",
max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model:
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index fdcd7a9e1738..35334ef13b7a 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -257,6 +257,7 @@ def _test_processing_correctness_mistral(
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
+ "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py
new file mode 100644
index 000000000000..7ec7c8002974
--- /dev/null
+++ b/tests/models/multimodal/processing/test_llama4.py
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Llama4's multimodal preprocessing kwargs."""
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.transformers_utils.tokenizer import encode_tokens
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id",
+ ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
+@pytest.mark.parametrize("mm_processor_kwargs", [{}])
+@pytest.mark.parametrize("num_imgs", [1, 5])
+@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
+@pytest.mark.parametrize("tokenized_prompt", [True, False])
+def test_processor_override(
+ image_assets: _ImageAssets,
+ model_id: str,
+ mm_processor_kwargs: dict,
+ num_imgs: int,
+ disable_mm_preprocessor_cache: bool,
+ tokenized_prompt: bool,
+):
+ """Ensure llama4 processor works properly."""
+ ctx = build_model_context(
+ model_id,
+ mm_processor_kwargs=mm_processor_kwargs,
+ limit_mm_per_prompt={"image": num_imgs},
+ disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
+ )
+ processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+ config = processor.info.get_hf_config()
+ tokenizer = processor.info.get_tokenizer()
+ hf_processor = processor.info.get_hf_processor()
+ vocab = tokenizer.get_vocab()
+
+ prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
+ + "<|image|>" * num_imgs \
+ + "<|eot|><|header_start|>assistant<|header_end|>"
+ mm_data = {
+ "image": [
+ image_assets[(i % len(image_assets))].pil_image
+ for i in range(num_imgs)
+ ]
+ }
+ if tokenized_prompt:
+ prompt = encode_tokens(tokenizer, prompt)
+
+ processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
+ mm_kwargs = processed_inputs["mm_kwargs"]
+
+ # place holder replacements
+ prompt_token_ids = processed_inputs["prompt_token_ids"]
+ assert prompt_token_ids.count(config.boi_token_index) == num_imgs
+ assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
+ assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
+ aspect_ratios = mm_kwargs["aspect_ratios"]
+ num_x_separators = num_y_separators = 0
+ for tiles_y, tiles_x in aspect_ratios:
+ if tiles_x * tiles_y > 1:
+ num_x_separators += (tiles_x - 1) * tiles_y
+ num_y_separators += tiles_y
+ assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
+ == num_x_separators
+ assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
+ == num_y_separators
+
+ # image token offsets
+ img_locs = processed_inputs["mm_placeholders"].get("image", [])
+ assert len(img_locs) == num_imgs
+ assert [img_loc["offset"] for img_loc in img_locs] == \
+ [i for i, v in enumerate(prompt_token_ids) \
+ if v == config.boi_token_index]
+
+ # patch sizes and masks
+ assert prompt_token_ids.count(config.image_token_index) \
+ == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
+ patch_token_id = vocab[hf_processor.img_patch_token]
+ num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
+ mm_counts = {"image": num_imgs}
+ assert num_patches / num_imgs <= \
+ processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
+ num_patches_per_chunk = processor.info.get_patch_per_chunk(
+ config.vision_config)
+ assert prompt_token_ids.count(config.image_token_index) \
+ == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
+ assert mm_kwargs["pixel_values"].shape[0] \
+ == mm_kwargs["patches_per_image"].sum()
+
+ for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
+ mm_kwargs["aspect_ratios"]):
+ assert embed_is_patch.shape[0] == \
+ len(tokenizer.encode(
+ hf_processor._prompt_split_image(
+ aspect_ratio, num_patches_per_chunk),
+ add_special_tokens=False))
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 574b8d9e1308..e61cbc5756f6 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -287,12 +287,16 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
- extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501
+ extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
+ max_transformers_version="4.48", # noqa: E501
+ transformers_version_reason="HF model is not compatible."), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
+ "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
+ min_transformers_version="4.51"),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 58705637ce94..cd2b8f00d521 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -7,6 +7,8 @@ from transformers import PretrainedConfig
from vllm import LLM
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
+from vllm.utils import GiB_bytes
+from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
from .registry import HF_EXAMPLE_MODELS
@@ -42,14 +44,21 @@ def test_can_initialize(model_arch):
self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0
- def _initalize_kv_caches_v1(self, vllm_config):
- # gpu_blocks (> 0), cpu_blocks
- return 1, 0
+ def _initialize_kv_caches_v1(self, vllm_config):
+ kv_cache_specs = self.model_executor.get_kv_cache_specs()
+ scheduler_kv_cache_config = get_kv_cache_config(
+ vllm_config,
+ kv_cache_specs[0],
+ 20 * GiB_bytes,
+ )
+
+ # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
+ return 1, 0, scheduler_kv_cache_config
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
- _initalize_kv_caches_v1)):
+ _initialize_kv_caches_v1)):
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
diff --git a/vllm/config.py b/vllm/config.py
index d6f931ca1a43..c232f0f5e223 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -358,6 +358,8 @@ class ModelConfig:
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
+ self.attention_chunk_size = getattr(self.hf_text_config,
+ "attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=hf_token, revision=revision)
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 9041b92a5de1..d7e8d045108e 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -500,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat", "idefics3"):
return ""
- if model_type == "mllama":
+ if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json
new file mode 100644
index 000000000000..f10e39482e58
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json
@@ -0,0 +1,200 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "256": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 8,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index a17afd1b357e..d6a27aa0ddc4 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
+ apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -96,8 +97,14 @@ def cutlass_moe_fp8(
n = w2_q.size(1)
topk = topk_ids.size(1)
+
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
+ if apply_router_weight_on_input:
+ assert topk == 1, \
+ "apply_router_weight_on_input is only implemented for topk=1"
+ # TODO: this only works for topK=1, will need to update for topK>1
+ a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
@@ -139,6 +146,8 @@ def cutlass_moe_fp8(
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
-
- return (c2[c_map].view(m, topk, k) *
- topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
+ # Gather tokens
+ c2 = c2[c_map].view(m, topk, k)
+ if not apply_router_weight_on_input:
+ c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
+ return c2.sum(dim=1)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 0817879c4d57..4ab99acb742f 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
- activation, use_fp8_w8a8, use_int8_w8a16,
- use_int4_w4a16, global_num_experts, expert_map,
- w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
- block_shape)
+ activation, apply_router_weight_on_input, use_fp8_w8a8,
+ use_int8_w8a16, use_int4_w4a16, global_num_experts,
+ expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
+ a2_scale, block_shape)
def inplace_fused_experts_fake(
@@ -980,6 +981,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@@ -1010,6 +1012,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@@ -1023,10 +1026,11 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
- False, activation, use_fp8_w8a8, use_int8_w8a16,
- use_int4_w4a16, global_num_experts, expert_map,
- w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
- a2_scale, block_shape)
+ False, activation, apply_router_weight_on_input,
+ use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
+ global_num_experts, expert_map, w1_scale,
+ w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+ block_shape)
def outplace_fused_experts_fake(
@@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
+ assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
@@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
@@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
- False,
+ apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
@@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
- True,
+ not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 661fb52bbee2..0e35d8a80988 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
- e_score_correction_bias: Optional[torch.Tensor] = None
+ e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
+ activation: str = "silu",
) -> torch.Tensor:
raise NotImplementedError
@@ -156,22 +158,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
- return self.forward(x=x,
- layer=layer,
- router_logits=router_logits,
- top_k=top_k,
- renormalize=renormalize,
- use_grouped_topk=use_grouped_topk,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- e_score_correction_bias=e_score_correction_bias,
- activation=activation)
+ return self.forward(
+ x=x,
+ layer=layer,
+ router_logits=router_logits,
+ top_k=top_k,
+ renormalize=renormalize,
+ use_grouped_topk=use_grouped_topk,
+ topk_group=topk_group,
+ num_expert_group=num_expert_group,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ custom_routing_function=custom_routing_function,
+ scoring_func=scoring_func,
+ e_score_correction_bias=e_score_correction_bias,
+ activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input)
def forward_cuda(
self,
@@ -188,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
@@ -202,15 +208,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
- return fused_experts(hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map)
+ return fused_experts(
+ hidden_states=x,
+ w1=layer.w13_weight,
+ w2=layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=True,
+ activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map)
def forward_cpu(
self,
@@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
**kwargs,
):
assert activation == "silu", f"{activation} is not supported."
+ assert apply_router_weight_on_input is False
return layer.ipex_fusion(
x,
use_grouped_topk,
@@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
@@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert topk_group is None
assert custom_routing_function is None
assert layer is not None
+ assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for HPU.")
@@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
+ assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU.")
@@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
super().__init__()
@@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
+ self.apply_router_weight_on_input = apply_router_weight_on_input
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
@@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module):
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
+ apply_router_weight_on_input=self.apply_router_weight_on_input,
)
if self.dp_size > 1:
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 76d3acb92fb8..5e8eb6c54c89 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -92,6 +92,7 @@ class RMSNorm(CustomOp):
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
+ dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
@@ -100,8 +101,10 @@ class RMSNorm(CustomOp):
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.has_weight = has_weight
-
- self.weight = torch.ones(hidden_size)
+ if dtype is not None:
+ self.weight = torch.ones(hidden_size, dtype=dtype)
+ else:
+ self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 473816fcc3ec..cb1d5400f3a0 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
@@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
+ if apply_router_weight_on_input:
+ raise NotImplementedError(
+ "Apply router weight on input is not supported for"
+ "fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index bf32bee89e89..f573c8ae5131 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -240,20 +241,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
- return fused_experts(x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=activation,
- use_fp8_w8a8=True,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale)
+ return fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=True,
+ activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ use_fp8_w8a8=True,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale)
class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
@@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
@@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
+ apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
@@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
+ if apply_router_weight_on_input:
+ raise NotImplementedError(
+ "Apply router weight on input is not supported for "
+ "fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index d18ca55afebd..be19b80975ec 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
- return fused_experts(x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=activation,
- use_int8_w8a16=True,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- w1_scale=layer.w13_scale,
- w2_scale=layer.w2_scale)
+ return fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=True,
+ activation=activation,
+ use_int8_w8a16=True,
+ global_num_experts=global_num_experts,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ expert_map=expert_map,
+ w1_scale=layer.w13_scale,
+ w2_scale=layer.w2_scale)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e7c733db5c00..4435644c4f84 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
+ apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9861e0a85b3f..6b499f81c55f 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
+ if apply_router_weight_on_input:
+ raise NotImplementedError(
+ "Apply router weight on input is not supported for"
+ "fused GGUF MoE method.")
+
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 9f53ffc1d7f6..0615bb4ab4df 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
+ if apply_router_weight_on_input is not None:
+ raise NotImplementedError(
+ "Apply router weight on input is not supported for"
+ "fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 41b75c9be05a..00c4b661ef2c 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -312,21 +313,23 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
- return fused_experts(x,
- layer.w13_qweight,
- layer.w2_qweight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- use_int4_w4a16=weight_bits == 4,
- use_int8_w8a16=weight_bits == 8,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- w1_scale=layer.w13_scales,
- w2_scale=layer.w2_scales,
- w1_zp=layer.w13_qzeros if has_zp else None,
- w2_zp=layer.w2_qzeros if has_zp else None,
- block_shape=[0, layer.group_size])
+ return fused_experts(
+ x,
+ layer.w13_qweight,
+ layer.w2_qweight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=True,
+ use_int4_w4a16=weight_bits == 4,
+ use_int8_w8a16=weight_bits == 8,
+ global_num_experts=global_num_experts,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ expert_map=expert_map,
+ w1_scale=layer.w13_scales,
+ w2_scale=layer.w2_scales,
+ w1_zp=layer.w13_qzeros if has_zp else None,
+ w2_zp=layer.w2_qzeros if has_zp else None,
+ block_shape=[0, layer.group_size])
@staticmethod
def get_weight_loader(layer, weight_loader):
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index bc26a455c6f2..d1146c0f039d 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
+ activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
- return fused_experts(x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- use_fp8_w8a8=True,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale)
+ return fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=True,
+ use_fp8_w8a8=True,
+ global_num_experts=global_num_experts,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ expert_map=expert_map,
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale)
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index fd27775b7dc0..624ed63ab8b4 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs
+class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: int,
+ is_neox_style: bool,
+ dtype: torch.dtype,
+ ):
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+ is_neox_style, dtype)
+
+ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+ inv_freqs = super()._compute_inv_freq(base)
+ inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
+ return inv_freqs
+
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
+ inv_freq = self._compute_inv_freq(self.base)
+
+ # self.max_position_embeddings here is number of image patches
+ # i.e. (image_size // patch_size) ** 2
+ num_patches = self.max_position_embeddings
+ img_idx = torch.arange(num_patches,
+ dtype=torch.int32) \
+ .reshape(num_patches, 1)
+ img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+ img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
+ num_patches_single_dim = int(math.sqrt(num_patches))
+ frequencies_x = img_idx % num_patches_single_dim
+ frequencies_y = img_idx // num_patches_single_dim
+ freqs_x = ((frequencies_x + 1)[..., None] *
+ inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
+ freqs_y = ((frequencies_y + 1)[..., None] *
+ inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
+ freqs = torch.cat([freqs_x, freqs_y],
+ dim=-1).float().contiguous()[..., ::2]
+ freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+ cache = torch.view_as_complex(
+ torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
+ return cache
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+ query_ = torch.view_as_complex(query.float().reshape(
+ *query.shape[:-1], -1, 2))
+ key_ = torch.view_as_complex(key.float().reshape(
+ *key.shape[:-1], -1, 2))
+ broadcast_shape = [
+ d if i == 1 or i == (query_.ndim - 1) else 1
+ for i, d in enumerate(query_.shape)
+ ]
+ freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
+ query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
+ key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
+ return query_out.type_as(query), key_out.type_as(key)
+
+
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
@@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
+ elif scaling_type == "mllama4":
+ rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
+ max_position, base,
+ is_neox_style, dtype)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 5649cf2dd2cf..7e43438851d1 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -111,10 +111,12 @@ def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
+ model_class: Optional[type[nn.Module]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
- model_class, _ = get_model_architecture(model_config)
+ if model_class is None:
+ model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 81b5d9bda9ac..caa4a5108a92 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
+ reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
+ reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@@ -292,7 +294,7 @@ class LlamaModel(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
- layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
+ layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
- "norm": "model.norm"
+ "norm": "model.norm",
}
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ def __init__(self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "model"))
+ prefix=maybe_prefix(prefix, "model"),
+ layer_type=layer_type)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
- def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
- return LlamaModel(vllm_config=vllm_config, prefix=prefix)
+ def _init_model(self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ layer_type: type[nn.Module] = LlamaDecoderLayer):
+ return LlamaModel(vllm_config=vllm_config,
+ prefix=prefix,
+ layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
new file mode 100644
index 000000000000..029f6044598c
--- /dev/null
+++ b/vllm/model_executor/models/llama4.py
@@ -0,0 +1,531 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
+# All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import Llama4TextConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
+from .utils import (AutoWeightsLoader, extract_layer_index,
+ is_pp_missing_parameter)
+
+
+class Llama4MoE(nn.Module):
+
+ @staticmethod
+ def custom_routing_function(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
+ router_scores = torch.sigmoid(router_scores.float()).to(
+ hidden_states.dtype)
+ return (router_scores, router_indices.to(torch.int32))
+
+ def __init__(self,
+ config: Llama4TextConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.top_k = config.num_experts_per_tok
+
+ intermediate_size_moe = config.intermediate_size
+ self.router = ReplicatedLinear(config.hidden_size,
+ config.num_local_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.router")
+
+ self.experts = FusedMoE(
+ num_experts=config.num_local_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ custom_routing_function=Llama4MoE.custom_routing_function,
+ intermediate_size=intermediate_size_moe,
+ apply_router_weight_on_input=True,
+ reduce_results=False,
+ renormalize=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts")
+
+ self.shared_expert = LlamaMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size_moe,
+ hidden_act="silu",
+ quant_config=quant_config,
+ bias=False,
+ prefix=f"{prefix}.shared_expert",
+ reduce_results=False, # We need to do scatter before reduce
+ )
+
+ def forward(self, hidden_states):
+ router_logits, _ = self.router(hidden_states)
+ shared_out = self.shared_expert(hidden_states)
+ routed_out = self.experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits,
+ )
+ experts_out = routed_out + shared_out
+
+ if self.tp_size > 1:
+ experts_out = tensor_model_parallel_all_reduce(experts_out)
+
+ return experts_out
+
+
+class Llama4Attention(nn.Module):
+
+ def __init__(self,
+ config: Llama4TextConfig,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[Dict[str, Any]] = None,
+ max_position_embeddings: int = 8192,
+ quant_config: Optional[QuantizationConfig] = None,
+ bias: bool = False,
+ bias_o_proj: bool = False,
+ cache_config: Optional[CacheConfig] = None,
+ prefix: str = "") -> None:
+ super().__init__()
+ self.layer_idx = extract_layer_index(prefix)
+ self.hidden_size = hidden_size
+ self.no_rope_layers = config.no_rope_layers
+ self.nope = self.no_rope_layers[self.layer_idx] == 0
+ self.use_qk_norm = config.use_qk_norm and not self.nope
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = config.head_dim
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ # TODO: attn_temperature_tuning should be a bool in huggingface
+ self.attn_temperature_tuning = self.nope and \
+ config.attn_temperature_tuning > 0
+
+ self.floor_scale = getattr(config, "floor_scale", 8192.0)
+ self.attn_scale = getattr(config, "attn_scale", 0.1)
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.n_rep = self.num_heads // self.num_kv_heads
+ self.q_norm = RMSNorm(
+ hidden_size=self.q_size,
+ eps=config.rms_norm_eps,
+ has_weight=False,
+ dtype=torch.float32,
+ ) if self.use_qk_norm else None
+ self.k_norm = RMSNorm(
+ hidden_size=self.kv_size,
+ eps=config.rms_norm_eps,
+ has_weight=False,
+ dtype=torch.float32,
+ ) if self.use_qk_norm else None
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size=hidden_size,
+ head_size=self.head_dim,
+ total_num_heads=self.total_num_heads,
+ total_num_kv_heads=self.total_num_kv_heads,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.o_proj = RowParallelLinear(
+ input_size=self.total_num_heads * self.head_dim,
+ output_size=hidden_size,
+ bias=bias_o_proj,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+ is_neox_style = True
+ is_gguf = quant_config and quant_config.get_name() == "gguf"
+ if is_gguf and config.model_type == "llama":
+ is_neox_style = False
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=int(rope_theta),
+ rope_scaling=rope_scaling if rope_scaling != "default" else None,
+ is_neox_style=is_neox_style,
+ ) if not self.nope else None
+
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ per_layer_sliding_window=None,
+ use_irope=not self.nope,
+ prefix=f"{prefix}.attn",
+ )
+
+ def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+ floor = torch.floor((positions + 1.0) / self.floor_scale)
+ attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+
+ return attn_scale.unsqueeze(-1)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ if self.rotary_emb is not None:
+ q, k = self.rotary_emb(positions, q, k)
+ if self.q_norm is not None:
+ q = self.q_norm(q.float()).to(q.dtype)
+ if self.k_norm is not None:
+ k = self.k_norm(k.float()).to(k.dtype)
+
+ # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
+ # to NoPE layers, where the inference-time temperature tuning function
+ # is customized to not affect short context
+ # while working at very long context
+ # https://arxiv.org/abs/2501.19399
+ #
+ # We should apply temperature tuning between (after) rotary / QK norm
+ # and (before) attention.
+ if self.attn_temperature_tuning and self.nope:
+ attn_scale = self._get_attn_scale(positions)
+ q = (q * attn_scale).to(q.dtype)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Llama4DecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4TextConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.layer_idx = extract_layer_index(prefix)
+ self.hidden_size = config.hidden_size
+ rope_theta = config.rope_theta
+ rope_scaling = config.rope_scaling
+ max_position_embeddings = config.max_position_embeddings
+
+ self.self_attn = Llama4Attention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ quant_config=quant_config,
+ bias=False,
+ bias_o_proj=False,
+ cache_config=cache_config,
+ prefix=f"{prefix}.self_attn",
+ )
+ is_moe_layer = (self.layer_idx +
+ 1) % config.interleave_moe_layer_step == 0
+ if is_moe_layer:
+ self.feed_forward = Llama4MoE(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.feed_forward",
+ )
+ else:
+ self.feed_forward = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size_mlp,
+ hidden_act="silu",
+ quant_config=quant_config,
+ bias=False,
+ prefix=f"{prefix}.feed_forward",
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(positions=positions,
+ hidden_states=hidden_states)
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ hidden_states = self.feed_forward(hidden_states)
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Llama4Model(LlamaModel):
+
+ def __init__(self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
+ self.num_experts = vllm_config.model_config.hf_config.num_local_experts
+ super().__init__(vllm_config=vllm_config,
+ prefix=prefix,
+ layer_type=layer_type)
+
+ def load_moe_expert_weights(
+ self,
+ name: str,
+ loaded_weight: torch.Tensor,
+ params_dict: Dict[str, nn.Parameter],
+ loaded_params: Set[str],
+ expert_params_mapping: List[Tuple[str, str, int, str]],
+ fused: bool = True,
+ ) -> bool:
+ expert_param_loaded = False
+ if "experts.gate_up_proj" in name:
+ loaded_weight = loaded_weight.chunk(2, dim=-1)
+ for (param_name, weight_name, expert_id,
+ shard_id) in expert_params_mapping:
+ new_loaded_weight = loaded_weight
+ if fused:
+ e_str, _, proj_str, _ = weight_name.split('.')
+ weight_name = f"{e_str}.{proj_str}"
+ param_name = f"{param_name}weight"
+ if weight_name not in name:
+ continue
+ full_param_name = name.replace(weight_name, param_name)
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ if ((name.endswith(".bias") or name.endswith("_bias"))
+ and name not in params_dict):
+ continue
+ param = params_dict[full_param_name]
+ weight_loader = param.weight_loader
+ if fused:
+ if "w13" in full_param_name:
+ shard_idx = 0 if shard_id == "w1" else 1
+ new_loaded_weight = new_loaded_weight[shard_idx]
+ new_loaded_weight = new_loaded_weight.transpose(-1, -2)
+ layer_idx = extract_layer_index(name)
+ # EP mapping
+ expert_map = self.layers[
+ layer_idx].feed_forward.experts.expert_map
+ if expert_map is not None:
+ local_expert_indices = (expert_map != -1) \
+ .nonzero() \
+ .flatten() \
+ .to(new_loaded_weight.device)
+ new_loaded_weight = new_loaded_weight[local_expert_indices]
+ expert_id = local_expert_indices[0].item()
+ else:
+ # TODO: add EP support for non fused weights
+ pass
+ weight_loader(param,
+ new_loaded_weight,
+ full_param_name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+
+ loaded_params.add(full_param_name)
+ expert_param_loaded = True
+ return expert_param_loaded
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
+ fused_experts_params = False
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.num_experts)
+ expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_up_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="gate_up_proj",
+ num_experts=1)
+ params_dict = dict(self.named_parameters())
+ loaded_params: Set[str] = set()
+ for name, loaded_weight in weights:
+ if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+ fused_experts_params = True
+ expert_params_mapping = expert_params_mapping_fused
+ if (self.quant_config is not None and
+ (scale_name := self.quant_config.get_cache_scale(name))):
+ # Loading kv cache quantization scales
+ param = params_dict[scale_name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+ loaded_weight[0])
+ weight_loader(param, loaded_weight)
+ loaded_params.add(scale_name)
+ continue
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name or "experts" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ loaded_params.add(name)
+ break
+ else:
+ moe_loaded = self.load_moe_expert_weights(
+ name,
+ loaded_weight,
+ params_dict,
+ loaded_params,
+ expert_params_mapping,
+ fused=fused_experts_params)
+
+ if not moe_loaded:
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Llama4ForCausalLM(LlamaForCausalLM):
+
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ # Update temperature tuning config from generation config
+ gen_config = vllm_config.model_config.try_get_generation_config()
+ gen_config.update(vllm_config.model_config.override_generation_config)
+ vllm_config.model_config.hf_config.attn_temperature_tuning \
+ = gen_config.get("attn_temperature_tuning", False)
+
+ super().__init__(vllm_config=vllm_config,
+ prefix=prefix,
+ layer_type=Llama4DecoderLayer)
+
+ def _init_model(self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
+ return Llama4Model(vllm_config=vllm_config,
+ prefix=prefix,
+ layer_type=layer_type)
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."]
+ if self.config.tie_word_embeddings else None),
+ )
+ weights = [
+ self.permute_qk_weight_for_rotary(name, loaded_weight)
+ for name, loaded_weight in weights
+ ]
+ return loader.load_weights(weights)
+
+ def permute_qk_weight_for_rotary(
+ self,
+ name: str,
+ loaded_weight: torch.Tensor,
+ ) -> Tuple[str, torch.Tensor]:
+
+ def permute(w: torch.Tensor, n_heads: int):
+ attn_in = self.config.head_dim * n_heads
+ attn_out = self.config.hidden_size
+
+ return w.view(n_heads, attn_in // n_heads // 2, 2,
+ attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+ modules = name.split(".")
+
+ # rotary embeds should be sliced
+ if ("wk" in modules or "k_proj" in modules) \
+ and modules[-1] == "weight":
+ loaded_weight = permute(loaded_weight,
+ self.config.num_key_value_heads)
+ elif ("wq" in modules or "q_proj" in modules) \
+ and modules[-1] == "weight":
+ loaded_weight = permute(loaded_weight,
+ self.config.num_attention_heads)
+
+ return name, loaded_weight
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
new file mode 100644
index 000000000000..dae98093bc6e
--- /dev/null
+++ b/vllm/model_executor/models/mllama4.py
@@ -0,0 +1,895 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
+# All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from collections.abc import Iterable, Mapping
+from functools import cached_property
+from itertools import tee
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
+
+import torch
+from torch import nn
+from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
+from transformers.image_utils import SizeDict
+from transformers.models.llama4 import Llama4Processor
+from transformers.models.llama4.image_processing_llama4_fast import (
+ find_supported_resolutions, get_best_fit)
+
+from vllm.attention.layer import MultiHeadAttention
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs import InputProcessingContext
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.model_loader.loader import _initialize_model
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+ MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .llama4 import Llama4ForCausalLM
+from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
+ merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+logger = init_logger(__name__)
+
+
+class Llama4ImagePatchInputs(TypedDict):
+ type: Literal["pixel_values"]
+ flat_data: torch.Tensor
+ """
+ Shape:
+ `(batch_size * num_chunks, num_channels, image size, image size)`
+ """
+ patches_per_image: torch.Tensor
+ """
+ The number of total patches for each image in the batch.
+
+ This is used to split the embeddings which has the first two dimensions
+ flattened just like `flat_data`.
+ """
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+ """
+ aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A list of aspect ratios corresponding to the number of tiles
+ in each dimension that each image in the batch corresponds to.
+
+ Shape:
+ `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
+ """
+
+
+class Llama4VisionMLP(nn.Module):
+
+ def __init__(self,
+ input_size: int,
+ intermediate_size: int,
+ output_size: int,
+ bias: bool,
+ output_activation: bool,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.fc1 = ColumnParallelLinear(
+ input_size=input_size,
+ output_size=intermediate_size,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
+ )
+ self.fc2 = RowParallelLinear(
+ input_size=intermediate_size,
+ output_size=output_size,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ )
+ self.activation_fn = nn.GELU()
+ self.output_activation = output_activation
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ if self.output_activation:
+ return self.activation_fn(hidden_states)
+ return hidden_states
+
+
+class Llama4MultiModalProjector(nn.Module):
+
+ def __init__(
+ self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.linear_1 = ColumnParallelLinear(
+ input_size=config.vision_config.vision_output_dim,
+ output_size=config.text_config.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ gather_output=True,
+ prefix=f"{prefix}.linear_1",
+ )
+
+ def forward(self, image_features):
+ hidden_states, _ = self.linear_1(image_features)
+ return hidden_states
+
+
+def pixel_shuffle(input_tensor, shuffle_ratio):
+ # input_tensor: [batch_size, num_patches, channels]
+ batch_size, num_patches, channels = input_tensor.shape
+ patch_size = int(math.sqrt(num_patches))
+
+ input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
+ batch_size, height, width, channels = input_tensor.size()
+
+ reshaped_tensor = input_tensor.view(batch_size, height,
+ int(width * shuffle_ratio),
+ int(channels / shuffle_ratio))
+ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+ reshaped_tensor = reshaped_tensor.view(batch_size,
+ int(height * shuffle_ratio),
+ int(width * shuffle_ratio),
+ int(channels / (shuffle_ratio**2)))
+ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+ output_tensor = reshaped_tensor.view(batch_size, -1,
+ reshaped_tensor.shape[-1])
+ return output_tensor
+
+
+class Llama4VisionPixelShuffleMLP(nn.Module):
+
+ def __init__(
+ self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+ self.inner_dim = int(config.projector_input_dim //
+ (self.pixel_shuffle_ratio**2))
+ self.output_dim = config.projector_output_dim
+ self.mlp = Llama4VisionMLP(
+ input_size=config.intermediate_size,
+ intermediate_size=config.projector_input_dim,
+ output_size=config.projector_output_dim,
+ bias=config.multi_modal_projector_bias,
+ output_activation=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
+ encoded_patches = pixel_shuffle(encoded_patches,
+ self.pixel_shuffle_ratio)
+ return self.mlp(encoded_patches)
+
+
+class Llama4VisionAttention(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.hidden_size // self.num_heads
+ assert self.num_heads % self.tp_size == 0
+ self.num_local_heads = self.num_heads // self.tp_size
+ self.q_size = self.num_local_heads * self.head_dim
+ self.kv_size = self.num_local_heads * self.head_dim
+ self.attention_dropout = config.attention_dropout
+ self.scaling = self.head_dim**-0.5
+
+ self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
+ self.scaling)
+ self.qkv_proj = QKVParallelLinear(
+ self.embed_dim,
+ self.head_dim,
+ self.num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.o_proj = RowParallelLinear(
+ self.num_heads * self.head_dim,
+ self.embed_dim,
+ bias=True,
+ input_is_parallel=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ self.rotary_emb = get_rope(
+ head_size=self.head_dim,
+ rotary_dim=config.hidden_size // config.num_attention_heads // 2,
+ # number of image patches
+ max_position=(config.image_size // config.patch_size)**2,
+ base=config.rope_theta,
+ rope_scaling={"rope_type": "mllama4"},
+ is_neox_style=False,
+ dtype=torch.complex64, # important
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ input_shape = hidden_states.shape[:-1]
+
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim)
+ k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim)
+ q, k = self.rotary_emb(q, k)
+
+ q = q.view(q.shape[0], q.shape[1], -1)
+ k = k.view(k.shape[0], k.shape[1], -1)
+
+ attn_output = self.attn(q, k, v)
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output, _ = self.o_proj(attn_output)
+
+ return attn_output
+
+
+class Llama4VisionEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.intermediate_size = config.intermediate_size
+
+ self.self_attn = Llama4VisionAttention(config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn")
+ self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ output_size=config.hidden_size,
+ bias=True,
+ output_activation=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ self.input_layernorm = nn.LayerNorm(config.hidden_size)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ ):
+ # Self Attention
+ residual = hidden_state
+ hidden_state = self.input_layernorm(hidden_state)
+ hidden_state = self.self_attn(hidden_state)
+ hidden_state = residual + hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ hidden_state = residual + hidden_state
+
+ outputs = (hidden_state, )
+ return outputs
+
+
+class Llama4VisionEncoder(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ Llama4VisionEncoderLayer(
+ config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ ) for layer_idx in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape
+ `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to
+ directly pass an embedded representation. This is useful if you
+ want more control over how to convert `input_ids` indices into
+ associated vectors than the model's internal embedding
+ lookup matrix.
+ """
+
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(hidden_states)
+ hidden_states = layer_outputs[0]
+
+ return hidden_states
+
+
+class Llama4UnfoldConvolution(nn.Module):
+
+ def __init__(self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ kernel_size = config.patch_size
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
+ stride=config.patch_size)
+ self.linear = ColumnParallelLinear(config.num_channels *
+ kernel_size[0] * kernel_size[1],
+ config.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ gather_output=True,
+ prefix=f"{prefix}.linear")
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.unfold(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 1)
+ hidden_states, _ = self.linear(hidden_states)
+ return hidden_states
+
+
+class Llama4VisionModel(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+
+ self.num_patches = (self.image_size // self.patch_size)**2 + 1
+ self.scale = config.hidden_size**-0.5
+
+ self.patch_embedding = Llama4UnfoldConvolution(
+ config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.patch_embedding")
+
+ self.class_embedding = nn.Parameter(self.scale *
+ torch.randn(self.hidden_size))
+ self.positional_embedding_vlm = nn.Parameter(
+ self.scale * torch.randn(self.num_patches, self.hidden_size))
+
+ # layer norms
+ self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
+ self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
+
+ # encoders
+ self.model = Llama4VisionEncoder(config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.model")
+ self.vision_adapter = Llama4VisionPixelShuffleMLP(
+ config, quant_config, prefix=f"{prefix}.vision_adapter")
+
+ def forward(
+ self,
+ images_flattened: torch.Tensor,
+ ) -> torch.Tensor:
+ # Patch embedding
+ hidden_state = self.patch_embedding(images_flattened)
+ num_tiles, num_patches, hidden_dim = hidden_state.shape
+
+ # Add cls token
+ class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1,
+ hidden_state.shape[-1])
+ hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
+ num_patches += 1
+
+ # Position embeddings
+ hidden_state = hidden_state.reshape(
+ num_tiles,
+ 1,
+ num_patches,
+ hidden_dim,
+ )
+ positional_embedding = self.positional_embedding_vlm.to(
+ dtype=hidden_state.dtype, device=hidden_state.device)
+ hidden_state = hidden_state + positional_embedding
+ hidden_state = self.layernorm_pre(hidden_state)
+ hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
+
+ # Apply encoder
+ hidden_state = self.model(hidden_state)
+ hidden_state = self.layernorm_post(hidden_state)
+
+ # Remove CLS token output
+ hidden_state = hidden_state[:, :-1, :]
+
+ # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+ hidden_state = self.vision_adapter(hidden_state)
+
+ return hidden_state
+
+
+class Mllama4ProcessingInfo(BaseProcessingInfo):
+
+ def __init__(self, ctx: InputProcessingContext) -> None:
+ super().__init__(ctx)
+
+ def get_hf_config(self) -> Llama4Config:
+ return self.ctx.get_hf_config(Llama4Config)
+
+ def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
+ return self.ctx.get_hf_processor(Llama4Processor,
+ use_fast=True,
+ **kwargs)
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": 10}
+
+ @staticmethod
+ def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
+ image_size = vision_config.image_size
+ patch_size = vision_config.patch_size
+
+ assert (
+ image_size %
+ patch_size == 0), f"chunk size {image_size} should be multiple of "
+ f"patch_size {patch_size}"
+
+ ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
+ return (image_size // patch_size)**2 // ds_ratio
+
+ def get_max_num_tiles(self) -> int:
+ image_processor = self.get_hf_processor().image_processor
+ return image_processor.max_patches
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ vision_config = self.get_hf_config().vision_config
+ # image_start + local tiles * (patches + 1 x separator) +
+ # 1 global tile * (image x 1 + patches) + image_end
+ token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
+ mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
+ return {"image": mm_max_tokens}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ vision_config = self.get_hf_config().vision_config
+ image_size = vision_config.image_size
+ # Result in the max possible feature size (h:w = 16:1)
+ return ImageSize(height=self.get_max_num_tiles() * image_size,
+ width=image_size)
+
+
+class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
+ ):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ tokenizer = self.info.get_tokenizer()
+
+ if mm_data is None:
+ return tokenizer(prompt, add_special_tokens=False) # exclude bos
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+
+ processor = self.info.get_hf_processor(**mm_kwargs)
+ image_processor = processor.image_processor
+ vision_config = self.info.get_hf_config().vision_config
+
+ if processed_outputs.get("pixel_values") is not None:
+ assert "images" in mm_data, \
+ "images expected to be in mm_data when pixel_values is present"
+
+ images = mm_data["images"]
+ parsed_images = (self._get_data_parser().parse_mm_data({
+ "image":
+ images
+ }).get_items("image", ImageProcessorItems))
+
+ tile_size = vision_config.image_size
+ possible_resolutions = find_supported_resolutions(
+ max_num_chunks=self.info.get_max_num_tiles(),
+ patch_size=SizeDict(height=tile_size, width=tile_size),
+ )
+ best_fit_sizes = [
+ get_best_fit(
+ (image.size[1], image.size[0]),
+ torch.tensor(possible_resolutions),
+ resize_to_max_canvas=image_processor.resize_to_max_canvas)
+ for image in parsed_images
+ ]
+ # TODO tile height/width do not necessarily need to match
+ aspect_ratios = [(image_size[0] // tile_size,
+ image_size[1] // tile_size)
+ for image_size in best_fit_sizes]
+ patches_per_image = [
+ 1 if r_h * r_w == 1 else 1 + r_h * r_w
+ for (r_h, r_w) in aspect_ratios
+ ]
+
+ # embed_is_patch should have one feature per image-related token:
+ # <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
+ # -> False
+ # <|patch|> -> True
+ # embed_is_patch has no entries corresponding to non-image-related
+ # tokens.
+ patch_id = tokenizer.get_vocab()[processor.img_patch_token]
+ num_patches_per_chunk = self.info.get_patch_per_chunk(
+ vision_config)
+ expanded_image_tokens_list = [
+ processor._prompt_split_image(aspect_ratio,
+ num_patches_per_chunk)
+ for aspect_ratio in aspect_ratios
+ ]
+ expanded_image_token_ids = [
+ tokenizer.encode(image_tokens, add_special_tokens=False)
+ for image_tokens in expanded_image_tokens_list
+ ]
+ embed_is_patch = [
+ torch.tensor(tokens) == patch_id
+ for tokens in expanded_image_token_ids
+ ]
+
+ processed_outputs["aspect_ratios"] = aspect_ratios
+ processed_outputs["patches_per_image"] = torch.tensor(
+ patches_per_image)
+ processed_outputs["embed_is_patch"] = embed_is_patch
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", patches_per_image),
+ patches_per_image=MultiModalFieldConfig.batched("image"),
+ aspect_ratios=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> List[PromptUpdate]:
+ assert (
+ mm_items.get_count("image", strict=False) == 0
+ or "aspect_ratios" in out_mm_kwargs
+ ), "Transformers expect to include aspect_ratios in out_mm_kwargs"
+
+ config = self.info.get_hf_config()
+ vision_config = config.vision_config
+
+ num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_token = hf_processor.image_token
+
+ def get_replacement(item_idx: int):
+ aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
+ return hf_processor._prompt_split_image(
+ aspect_ratio=aspect_ratio,
+ num_patches_per_chunk=num_patches_per_chunk)
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_replacement,
+ )
+ ]
+
+
+class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ num_images = mm_counts.get("image", 0)
+
+ (target_width,
+ target_height) = self.info.get_image_size_with_most_features()
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+ image_token = self.info.get_hf_processor().fake_image_token
+ return ProcessorInputs(
+ prompt_text=image_token * num_images,
+ mm_data=mm_data,
+ )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Mllama4MultiModalProcessor,
+ info=Mllama4ProcessingInfo,
+ dummy_inputs=Mllama4DummyInputsBuilder,
+)
+class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ self.config = config
+ self.quant_config = quant_config
+ self.multimodal_config = multimodal_config
+ self.vision_model = Llama4VisionModel(config.vision_config,
+ None,
+ prefix=maybe_prefix(
+ prefix, "vision_model"))
+ self.multi_modal_projector = Llama4MultiModalProjector(
+ self.config,
+ None,
+ prefix=maybe_prefix(prefix, "multi_modal_projector"))
+
+ self.language_model = _initialize_model(
+ vllm_config=vllm_config.with_hf_config(config.text_config),
+ prefix=maybe_prefix(prefix, "language_model"),
+ model_class=Llama4ForCausalLM,
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
+ # num_images, 1, num_chunks, channel, image_size, image_size
+ pixel_values = kwargs.pop("pixel_values", None)
+ if pixel_values is None:
+ return None
+
+ # num_images x num_chunks, channel, image_size, image_size
+ # TODO: confirm handling for variable lengths
+ flat_pixel_values = flatten_bn(pixel_values, concat=True)
+ patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
+
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ aspect_ratios = kwargs.pop("aspect_ratios", None)
+ if not isinstance(aspect_ratios, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of aspect_ratios. "
+ f"Got type: {type(aspect_ratios)}")
+
+ return Llama4ImagePatchInputs(
+ type="pixel_values",
+ flat_data=flat_pixel_values,
+ patches_per_image=patches_per_image,
+ embed_is_patch=embed_is_patch,
+ aspect_ratios=aspect_ratios,
+ )
+
+ def _process_image_input(
+ self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
+ flat_data = image_input["flat_data"]
+ patches_per_image = image_input["patches_per_image"].tolist()
+ vision_embeddings_flat = self.vision_model(flat_data)
+ return vision_embeddings_flat.split(patches_per_image, dim=0)
+
+ def get_multimodal_embeddings(self,
+ **kwargs) -> Optional[MultiModalEmbeddings]:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+
+ # num_images x [num_chunks, num_patches, hidden_dim]
+ image_features = self._process_image_input(image_input)
+ # num_images x [num_chunks x num_patches, hidden_dim]
+ image_features_flat = [img.flatten(0, 1) for img in image_features]
+ # num_images x [1, input_len] -> num_images x [input_len]
+ embed_is_patch_flat = [
+ is_patch.flatten(0, 1)
+ for is_patch in image_input["embed_is_patch"]
+ ]
+
+ return scatter_patch_features(
+ image_features_flat,
+ embed_is_patch_flat,
+ )
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+ if multimodal_embeddings is not None:
+ multimodal_embeddings = torch.cat(multimodal_embeddings)
+ mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, select_patch_features(mm_embeddings),
+ self.config.image_token_index)
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner,
+ # this condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ return self.language_model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def sample(self, logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
+
+ def separate_weights(
+ self,
+ weights: Iterable[Tuple[str, torch.Tensor]],
+ prefix: str,
+ ) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
+ str, torch.Tensor]]]:
+ weights1, weights2 = tee(weights, 2)
+
+ def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
+ for name, data in weights1:
+ if name.startswith(prefix):
+ yield (name, data)
+
+ def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
+ for name, data in weights2:
+ if not name.startswith(prefix):
+ yield (name, data)
+
+ return get_prefix_weights(), get_other_weights()
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+ (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+ (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters())
+ updated_params: Set[str] = set()
+
+ # language_model is an Llama4ForCausalLM instance. We load it's
+ # using llama4's load_weights routine.
+ language_model_weights, other_weights = self.separate_weights(
+ weights, prefix="language_model.model.")
+ loader = AutoWeightsLoader(self)
+ loaded_language_model_params = loader.load_weights(
+ language_model_weights)
+ assert loaded_language_model_params is not None
+ updated_params.update(loaded_language_model_params)
+
+ for name, loaded_weight in other_weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ updated_params.add(name)
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+
+ weight_loader(param, loaded_weight)
+ updated_params.add(name)
+ return updated_params
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 080aef8982d5..3abbb1f0c3b6 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -196,6 +196,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
+ "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
index 062b1c2cf5f5..379e19e1beea 100644
--- a/vllm/model_executor/models/telechat2.py
+++ b/vllm/model_executor/models/telechat2.py
@@ -19,9 +19,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Set, Tuple, Type
+from typing import Iterable, Set, Tuple
import torch
+import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -124,7 +125,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
- layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
+ layer_type: type[nn.Module] = LlamaDecoderLayer):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
def load_weights(self, weights: Iterable[Tuple[str,
diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py
index e670b1df08f7..e05f23f99e97 100644
--- a/vllm/model_executor/models/teleflm.py
+++ b/vllm/model_executor/models/teleflm.py
@@ -22,9 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Type
-
import torch
+import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -39,7 +38,7 @@ class TeleFLMModel(LlamaModel):
*,
vllm_config: VllmConfig,
prefix: str = "",
- layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer,
+ layer_type: type[nn.Module] = LlamaDecoderLayer,
):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 92e4ffd0371a..1a8d2420db7a 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
+ # for local attention
+ @dataclass
+ class LocalAttentionMetadata:
+ local_query_start_loc: torch.Tensor
+ local_seqused_k: torch.Tensor
+ local_block_table: torch.Tensor
+ local_max_query_len: int
+ local_max_seq_len: int
+
+ local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if are performing a chunked prefill a batch of 3 sequences:
+# q_seqlens = [4, 10, 5]
+# kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+# k_toks > 0 1 2 3 4 5
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# 2 | 1 1 1 1 1
+# 3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+# attention mask like:
+# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+# k_toks > 0 1 2 3 4 5
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# 2 | 1
+# 3 | 1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+# sequences into local ("virtual") batches, where each local batch item is a
+# local attention block, so in this case batch idx 0 would be broken up into:
+#
+# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
+# k_toks > 0 1 2 3
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
+# k_toks > 4 5
+# q_toks v _____________
+# 2 | 1
+# 3 | 1 1
+#
+# e.g. if we have:
+# attn_chunk_size = 4
+# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+# __b0__ ______b1______ __b2__ < orig batch indices
+# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
+# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
+# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
+# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+ attn_chunk_size: int,
+ query_start_loc_np: np.ndarray,
+ seq_lens_np: np.ndarray,
+ block_table: torch.Tensor,
+ page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+ q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+ actual_batch_size = seq_lens_np.shape[0]
+
+ # Handle if we are starting in the middle of a local attention block,
+ # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+ # the number of tokens that are not in the first local attention block and
+ # then we can simply use a cdiv for the rest.
+ # For example if we have:
+ # attn_chunk_size = 4
+ # q_seqlens = [4, 10, 5]
+ # k_seqlens = [6, 17, 9]
+ # Then we would get:
+ # new_tokens_in_first_block = [2, 1, 4]
+ # local_blocks = [2, 4, 2]
+ q_tokens_in_first_block = np.minimum(
+ attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
+ q_seqlens).astype(np.int32)
+ tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+ local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
+ attn_chunk_size)
+
+ # Once we know the number of local blocks we can compute the request spans
+ # for each batch idx, we can figure out the number of "virtual" requests we
+ # have to make,
+ # For the above example we would get:
+ # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+ #
+ # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+ # (TODO: max a utility to share this code with _prepare_inputs)
+ # arange step 1. [2, 4, 2] -> [2, 6, 8]
+ cu_num_blocks = np.cumsum(local_blocks)
+ virtual_batches = cu_num_blocks[-1]
+ # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+ block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+ # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+ arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+ # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+ rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+ # Then we can compute the seqlens_q_local, handling the fact that the
+ # first and last blocks could be partial
+ seqlens_q_local = \
+ np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+ # set the first block since this may be a partial block
+ seqlens_q_local[arange == 0] = q_tokens_in_first_block
+ # set the remaining blocks
+ seqlens_q_local[arange > 0] = np.minimum(
+ seqlens_q_local - attn_chunk_size * (arange - 1),
+ attn_chunk_size)[arange > 0]
+
+ # convert from q_seqlens to cu_seqlens_q
+ cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
+ .astype(np.int32)
+
+ # compute the seqlens_k_local,
+ # basically a full local attention block for all but the last block in each
+ # batch
+ # For our example this will be:
+ # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+ seqlens_k_local = np.full(cu_num_blocks[-1],
+ attn_chunk_size,
+ dtype=np.int32)
+ seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+ k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
+ (rarange * attn_chunk_size + \
+ np.repeat(tokens_in_last_block, local_blocks))
+ # For the example the local attention blocks start at:
+ # _b0_ _____b1_____ _b2_
+ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+ block_starts = k_seqstarts_absolute // page_size
+ assert attn_chunk_size % page_size == 0, \
+ f"attn_chunk_size {attn_chunk_size} is not " \
+ f"divisible by page_size {page_size}"
+ pages_per_local_batch = attn_chunk_size // page_size
+
+ # Create a block_table for the local attention blocks
+ # For out example if we have a block-table like (assuming page_size=2):
+ # block_table = [
+ # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
+ # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+ # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+ # ]
+ # Then for the local batches we would want a block-table like
+ # block_table_local = [
+ # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
+ # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
+ # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+ # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+ # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+ # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+ # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+ # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+ # ]
+ block_indices= np.broadcast_to(
+ np.arange(pages_per_local_batch, dtype=np.int32),
+ (virtual_batches, pages_per_local_batch)) \
+ + np.expand_dims(block_starts, axis=1)
+ block_indices = block_indices.flatten()
+ batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
+ local_blocks * pages_per_local_batch)
+ block_table_local = block_table[batch_indices, block_indices]\
+ .view(virtual_batches, -1)
+
+ return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
+ block_table_local
+
class FlashAttentionMetadataBuilder:
@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
- query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
- self.runner.device, non_blocking=True)
- seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
- non_blocking=True)
+ query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+ query_start_loc = query_start_loc_cpu.to(self.runner.device,
+ non_blocking=True)
+ seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
+ seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
+ # for local attention
+ local_attn_metadata = None
+ if self.runner.attention_chunk_size is not None:
+ seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
+ virt_block_table = make_local_attention_virtual_batches(
+ self.runner.attention_chunk_size,
+ self.runner.query_start_loc_np[:num_reqs + 1],
+ self.runner.seq_lens_np[:num_reqs],
+ block_table,
+ self.runner.block_size,
+ )
+ local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+ local_query_start_loc=torch.from_numpy(
+ virt_q_cu_seqlens_np).to(self.runner.device,
+ non_blocking=True),
+ local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
+ self.runner.device, non_blocking=True),
+ local_block_table=virt_block_table,
+ local_max_query_len=seqlens_q_local_np.max(),
+ local_max_seq_len=virt_k_seqlens_np.max(),
+ )
+
use_cascade = common_prefix_len > 0
if use_cascade:
- # TODO: Optimize.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
+ local_attn_metadata=local_attn_metadata,
)
return attn_metadata
@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
+ use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
+ self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
- descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
- key.shape[1])
+
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
- if not attn_metadata.use_cascade:
- # Regular attention (common case).
+ use_local_attn = \
+ (self.use_irope and attn_metadata.local_attn_metadata is not None)
+
+ if not attn_metadata.use_cascade or use_local_attn:
+ if use_local_attn:
+ assert attn_metadata.local_attn_metadata is not None
+ local_metadata = attn_metadata.local_attn_metadata
+ cu_seqlens_q = local_metadata.local_query_start_loc
+ seqused_k = local_metadata.local_seqused_k
+ max_seqlen_q = local_metadata.local_max_query_len
+ max_seqlen_k = local_metadata.local_max_seq_len
+ block_table = local_metadata.local_block_table
+ else:
+ cu_seqlens_q = attn_metadata.query_start_loc
+ seqused_k = attn_metadata.seq_lens
+ max_seqlen_q = attn_metadata.max_query_len
+ max_seqlen_k = attn_metadata.max_seq_len
+ block_table = attn_metadata.block_table
+
+ descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
- cu_seqlens_q=attn_metadata.query_start_loc,
- max_seqlen_q=attn_metadata.max_query_len,
- seqused_k=attn_metadata.seq_lens,
- max_seqlen_k=attn_metadata.max_seq_len,
+ cu_seqlens_q=cu_seqlens_q,
+ max_seqlen_q=max_seqlen_q,
+ seqused_k=seqused_k,
+ max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
- block_table=attn_metadata.block_table,
+ block_table=block_table,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
return output
+ assert not use_local_attn, (
+ "Cascade attention does not support local attention.")
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 15b49b14c1dd..5f9610470567 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
+ use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
+ self.use_irope = use_irope
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -156,24 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
layer._v_scale,
)
+ use_local_attn = \
+ (self.use_irope and attn_metadata.local_attn_metadata is not None)
+
+ if use_local_attn:
+ assert attn_metadata.local_attn_metadata is not None
+ local_metadata = attn_metadata.local_attn_metadata
+ cu_seqlens_q = local_metadata.local_query_start_loc
+ sequesd_k = local_metadata.local_seqused_k
+ max_seqlen_q = local_metadata.local_max_query_len
+ max_seqlen_k = local_metadata.local_max_seq_len
+ block_table = local_metadata.local_block_table
+ else:
+ cu_seqlens_q = attn_metadata.query_start_loc
+ sequesd_k = attn_metadata.seq_lens
+ max_seqlen_q = attn_metadata.max_query_len
+ max_seqlen_k = attn_metadata.max_seq_len
+ block_table = attn_metadata.block_table
+
# Compute attention and update output up to `num_actual_tokens`.
- chunked_prefill_paged_decode(
- query=query[:num_actual_tokens],
- key=key[:num_actual_tokens],
- value=value[:num_actual_tokens],
- output=output[:num_actual_tokens],
- kv_cache_dtype=self.kv_cache_dtype,
- key_cache=key_cache,
- value_cache=value_cache,
- block_table=attn_metadata.block_table,
- query_start_loc=attn_metadata.query_start_loc,
- seq_lens=attn_metadata.seq_lens,
- max_seq_len=attn_metadata.max_seq_len,
- max_query_len=attn_metadata.max_query_len,
- k_scale=layer._k_scale,
- v_scale=layer._v_scale,
- alibi_slopes=self.alibi_slopes,
- sliding_window=self.sliding_window[0],
- sm_scale=self.scale)
+ chunked_prefill_paged_decode(query=query[:num_actual_tokens],
+ key=key[:num_actual_tokens],
+ value=value[:num_actual_tokens],
+ output=output[:num_actual_tokens],
+ kv_cache_dtype=self.kv_cache_dtype,
+ key_cache=key_cache,
+ value_cache=value_cache,
+ block_table=block_table,
+ query_start_loc=cu_seqlens_q,
+ seq_lens=sequesd_k,
+ max_seq_len=max_seqlen_k,
+ max_query_len=max_seqlen_q,
+ k_scale=layer._k_scale,
+ v_scale=layer._v_scale,
+ alibi_slopes=self.alibi_slopes,
+ sliding_window=self.sliding_window[0],
+ sm_scale=self.scale)
return output
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 82b07c6cd327..5133c637f0e0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
+ self.attention_chunk_size = model_config.attention_chunk_size
self.attn_backend = get_attn_backend(
self.head_size,