diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 2bf0b6fd9a169..2f6cc45be77e6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -82,7 +82,7 @@ steps:
- bash standalone_tests/python_only_compile.sh
- label: Basic Correctness Test # 30min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
fast_check: true
torch_nightly: true
source_file_dependencies:
@@ -99,7 +99,7 @@ steps:
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- label: Chunked Prefill Test
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/basic_correctness/test_chunked_prefill
@@ -108,7 +108,7 @@ steps:
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test # 10min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
fast_check: true
source_file_dependencies:
- vllm/core
@@ -209,7 +209,7 @@ steps:
- pytest -v -s distributed/test_eplb_execute.py
- label: Metrics, Tracing Test # 10min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
num_gpus: 2
source_file_dependencies:
- vllm/
@@ -228,7 +228,7 @@ steps:
##### 1 GPU test #####
- label: Regression Test # 5min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/test_regression
@@ -280,7 +280,7 @@ steps:
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: Examples Test # 25min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/examples"
source_file_dependencies:
- vllm/entrypoints
@@ -305,7 +305,7 @@ steps:
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
- label: Prefix Caching Test # 9min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/prefix_caching
@@ -314,7 +314,7 @@ steps:
- label: Platform Tests (CUDA)
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/cuda
@@ -353,9 +353,10 @@ steps:
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
+ - pytest -v -s compile/test_fusion_all_reduce.py
- label: PyTorch Fullgraph Smoke Test # 9min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -368,7 +369,7 @@ steps:
- pytest -v -s compile/piecewise/test_full_cudagraph.py
- label: PyTorch Fullgraph Test # 18min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -377,7 +378,7 @@ steps:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Core Operation Test
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/
- tests/kernels/core
@@ -416,7 +417,7 @@ steps:
parallelism: 2
- label: Kernels Mamba Test
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
@@ -424,7 +425,7 @@ steps:
- pytest -v -s kernels/mamba
- label: Tensorizer Test # 11min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
@@ -437,7 +438,7 @@ steps:
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
- label: Model Executor Test
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/model_executor
- tests/model_executor
@@ -447,7 +448,7 @@ steps:
- pytest -v -s model_executor
- label: Benchmarks # 9min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/.buildkite"
source_file_dependencies:
- benchmarks/
@@ -455,7 +456,7 @@ steps:
- bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test # 10min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/benchmarks/
@@ -494,7 +495,7 @@ steps:
- pytest -s entrypoints/openai/correctness/
- label: Encoder Decoder tests # 5min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/encoder_decoder
@@ -502,7 +503,7 @@ steps:
- pytest -v -s encoder_decoder
- label: OpenAI-Compatible Tool Use # 20 min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
fast_check: false
source_file_dependencies:
- vllm/
@@ -623,7 +624,7 @@ steps:
# This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
optional: true
commands:
- echo 'Testing custom models...'
@@ -658,7 +659,7 @@ steps:
##### multi gpus test #####
- label: Distributed Comm Ops Test # 7min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
@@ -755,7 +756,7 @@ steps:
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Multi-step Tests (4 GPUs) # 36min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -776,7 +777,7 @@ steps:
- pytest -v -s multi_step/test_correctness_llm.py
- label: Pipeline Parallelism Test # 45min
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -790,7 +791,7 @@ steps:
- pytest -v -s distributed/test_pipeline_parallel.py
- label: LoRA TP Test (Distributed)
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental]
num_gpus: 4
source_file_dependencies:
- vllm/lora
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 43522ef8fb8dd..69aeee67a4300 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -370,6 +370,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
fi
# Install vllm wheel first, so that torch etc will be installed.
+# NOTE(review): "!bang" looks like a leftover debug/cache-busting marker — confirm intent or drop before merge.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system dist/*.whl --verbose \
diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu
index b9fc9def88190..2190151369761 100644
--- a/docker/Dockerfile.tpu
+++ b/docker/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20250724"
+ARG NIGHTLY_DATE="20250730"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 5a9823bb6bae7..f5d9e3b22f2a6 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
+| `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index ce735f3b27dfe..184c30891eca7 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -13,6 +13,38 @@ except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
+QUESTION = "What is the content of each image?"
+IMAGE_URLS = [
+ "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
+ "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
+ "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
+]
+
+
+def get_custom_mm_prompts(num_prompts):
+ prompts = []
+ for url in IMAGE_URLS:
+ prompts.append(
+ [
+ {"type": "image_url", "image_url": {"url": url}},
+ {"type": "text", "text": QUESTION},
+ ]
+ )
+ if num_prompts > len(IMAGE_URLS):
+ prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
+
+ return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
+
+
def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
@@ -35,6 +67,7 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
+ parser.add_argument("--custom-mm-prompts", action="store_true")
return parser.parse_args()
@@ -44,14 +77,26 @@ def main():
model_dir = args.model_dir
if args.model_dir is None:
+ if args.custom_mm_prompts:
+ raise ValueError(
+            "custom_mm_prompts requires a multimodal model; "
+            "the default llama3.1-8b-instruct is not multimodal. "
+            "Please specify --model-dir with a multimodal model."
+ )
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ args.custom_skip_chat_template = True
- prompts = get_samples(args, tokenizer)
- # add_special_tokens is False to avoid adding bos twice when using chat templates
- prompt_ids = [
- tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
- ]
+ if not args.custom_mm_prompts:
+ prompts = get_samples(args, tokenizer)
+ # add_special_tokens is False to avoid adding bos twice
+ # when using chat templates
+ prompt_ids = [
+ tokenizer.encode(prompt.prompt, add_special_tokens=False)
+ for prompt in prompts
+ ]
+ else:
+ prompts = get_custom_mm_prompts(args.num_prompts)
if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
@@ -85,10 +130,17 @@ def main():
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
+ limit_mm_per_prompt={"image": 5},
+ disable_chunked_mm_input=True,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
- outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
+ if not args.custom_mm_prompts:
+ outputs = llm.generate(
+ prompt_token_ids=prompt_ids, sampling_params=sampling_params
+ )
+ else:
+ outputs = llm.chat(prompts, sampling_params=sampling_params)
# print the generated text
if args.print_output:
diff --git a/requirements/test.txt b/requirements/test.txt
index d45048aae5809..4aaca2afea266 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -22,9 +22,7 @@ aiohttp==3.10.11
aiohttp-cors==0.8.1
# via ray
aiosignal==1.3.1
- # via
- # aiohttp
- # ray
+ # via aiohttp
albucore==0.0.16
# via terratorch
albumentations==1.4.6
@@ -139,7 +137,7 @@ contourpy==1.3.0
# via matplotlib
cramjam==2.9.0
# via fastparquet
-cupy-cuda12x==13.3.0
+cupy-cuda12x==13.5.1
# via ray
cycler==0.12.1
# via matplotlib
@@ -226,7 +224,6 @@ frozenlist==1.5.0
# via
# aiohttp
# aiosignal
- # ray
fsspec==2024.9.0
# via
# datasets
@@ -603,10 +600,18 @@ opencv-python-headless==4.11.0.86
opentelemetry-api==1.35.0
# via
# mlflow-skinny
+ # opentelemetry-exporter-prometheus
# opentelemetry-sdk
# opentelemetry-semantic-conventions
+opentelemetry-exporter-prometheus==0.56b0
+ # via ray
+opentelemetry-proto==1.36.0
+ # via ray
opentelemetry-sdk==1.35.0
- # via mlflow-skinny
+ # via
+ # mlflow-skinny
+ # opentelemetry-exporter-prometheus
+ # ray
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2
@@ -697,7 +702,9 @@ pqdm==0.2.0
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
- # via ray
+ # via
+ # opentelemetry-exporter-prometheus
+ # ray
propcache==0.2.0
# via yarl
proto-plus==1.26.1
@@ -707,6 +714,7 @@ protobuf==5.28.3
# google-api-core
# googleapis-common-protos
# mlflow-skinny
+ # opentelemetry-proto
# proto-plus
# ray
# tensorboardx
@@ -854,7 +862,7 @@ rasterio==1.4.3
# rioxarray
# terratorch
# torchgeo
-ray==2.43.0
+ray==2.48.0
# via -r requirements/test.in
redis==5.2.0
# via tensorizer
diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 2d0d8bd8457e3..7bb77c4a99636 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -19,8 +19,8 @@ nixl==0.3.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.9.0.dev20250724
-torchvision==0.24.0.dev20250724
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
+torch==2.9.0.dev20250730
+torchvision==0.24.0.dev20250730
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
diff --git a/setup.py b/setup.py
index bf3391e2db19e..6d615d122d69e 100644
--- a/setup.py
+++ b/setup.py
@@ -282,10 +282,69 @@ class cmake_build_ext(build_ext):
self.copy_file(file, dst_file)
-class repackage_wheel(build_ext):
+class precompiled_wheel_utils:
"""Extracts libraries and other files from an existing wheel."""
- def get_base_commit_in_main_branch(self) -> str:
+ @staticmethod
+ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
+ import tempfile
+ import zipfile
+
+ temp_dir = None
+ try:
+ if not os.path.isfile(wheel_url_or_path):
+ wheel_filename = wheel_url_or_path.split("/")[-1]
+ temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
+ wheel_path = os.path.join(temp_dir, wheel_filename)
+ print(f"Downloading wheel from {wheel_url_or_path} "
+ f"to {wheel_path}")
+ from urllib.request import urlretrieve
+ urlretrieve(wheel_url_or_path, filename=wheel_path)
+ else:
+ wheel_path = wheel_url_or_path
+ print(f"Using existing wheel at {wheel_path}")
+
+ package_data_patch = {}
+
+ with zipfile.ZipFile(wheel_path) as wheel:
+ files_to_copy = [
+ "vllm/_C.abi3.so",
+ "vllm/_moe_C.abi3.so",
+ "vllm/_flashmla_C.abi3.so",
+ "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
+ "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
+ "vllm/cumem_allocator.abi3.so",
+ ]
+
+ compiled_regex = re.compile(
+ r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
+ file_members = list(
+ filter(lambda x: x.filename in files_to_copy,
+ wheel.filelist))
+ file_members += list(
+ filter(lambda x: compiled_regex.match(x.filename),
+ wheel.filelist))
+
+ for file in file_members:
+ print(f"[extract] {file.filename}")
+ target_path = os.path.join(".", file.filename)
+ os.makedirs(os.path.dirname(target_path), exist_ok=True)
+ with wheel.open(file.filename) as src, open(
+ target_path, "wb") as dst:
+ shutil.copyfileobj(src, dst)
+
+ pkg = os.path.dirname(file.filename).replace("/", ".")
+ package_data_patch.setdefault(pkg, []).append(
+ os.path.basename(file.filename))
+
+ return package_data_patch
+ finally:
+ if temp_dir is not None:
+ print(f"Removing temporary directory {temp_dir}")
+ shutil.rmtree(temp_dir)
+
+ @staticmethod
+ def get_base_commit_in_main_branch() -> str:
# Force to use the nightly wheel. This is mainly used for CI testing.
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
return "nightly"
@@ -334,115 +393,6 @@ class repackage_wheel(build_ext):
"wheel may not be compatible with your dev branch: %s", err)
return "nightly"
- def run(self) -> None:
- assert _is_cuda(
- ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
-
- wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
- if wheel_location is None:
- base_commit = self.get_base_commit_in_main_branch()
- wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
- # Fallback to nightly wheel if latest commit wheel is unavailable,
- # in this rare case, the nightly release CI hasn't finished on main.
- if not is_url_available(wheel_location):
- wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
-
- import zipfile
-
- if os.path.isfile(wheel_location):
- wheel_path = wheel_location
- print(f"Using existing wheel={wheel_path}")
- else:
- # Download the wheel from a given URL, assume
- # the filename is the last part of the URL
- wheel_filename = wheel_location.split("/")[-1]
-
- import tempfile
-
- # create a temporary directory to store the wheel
- temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
- wheel_path = os.path.join(temp_dir, wheel_filename)
- print(f"Downloading wheel from {wheel_location} to {wheel_path}")
- from urllib.request import urlretrieve
- try:
- urlretrieve(wheel_location, filename=wheel_path)
- except Exception as e:
- from setuptools.errors import SetupError
- raise SetupError(
- f"Failed to get vLLM wheel from {wheel_location}") from e
-
- # Set the dist_dir for Docker build context
- dist_dir = ("/workspace/dist"
- if envs.VLLM_DOCKER_BUILD_CONTEXT else "dist")
- os.makedirs(dist_dir, exist_ok=True)
-
- # Extract only necessary compiled .so files from precompiled wheel
- with zipfile.ZipFile(wheel_path) as wheel:
- # Get version from METADATA (optional, mostly useful for logging)
- metadata_file = next((n for n in wheel.namelist()
- if n.endswith(".dist-info/METADATA")), None)
- if not metadata_file:
- raise RuntimeError(
- "Could not find METADATA in precompiled wheel.")
- metadata = wheel.read(metadata_file).decode()
- version_line = next((line for line in metadata.splitlines()
- if line.startswith("Version: ")), None)
- if not version_line:
- raise RuntimeError(
- "Could not determine version from METADATA.")
- version = version_line.split(": ")[1].strip()
-
- print(f"Extracting precompiled kernels from vLLM wheel version: "
- f"{version}")
-
- # List of compiled shared objects to extract
- files_to_copy = [
- "vllm/_C.abi3.so",
- "vllm/_moe_C.abi3.so",
- "vllm/_flashmla_C.abi3.so",
- "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
- "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
- "vllm/cumem_allocator.abi3.so",
- ]
-
- file_members = list(
- filter(lambda x: x.filename in files_to_copy, wheel.filelist))
- compiled_regex = re.compile(
- r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
- file_members += list(
- filter(lambda x: compiled_regex.match(x.filename),
- wheel.filelist))
-
- for file in file_members:
- print(f"Extracting and including {file.filename} "
- "from existing wheel")
- package_name = os.path.dirname(file.filename).replace("/", ".")
- file_name = os.path.basename(file.filename)
-
- if package_name not in package_data:
- package_data[package_name] = []
-
- output_base = (dist_dir
- if envs.VLLM_DOCKER_BUILD_CONTEXT else ".")
- target_path = os.path.join(output_base, file.filename)
- os.makedirs(os.path.dirname(target_path), exist_ok=True)
- with wheel.open(file.filename) as src, open(target_path,
- "wb") as dst:
- shutil.copyfileobj(src, dst)
-
- package_data[package_name].append(file_name)
-
- # Copy wheel into dist dir for Docker to consume (e.g., via --mount)
- if envs.VLLM_DOCKER_BUILD_CONTEXT:
- arch_tag = "cp38-abi3-manylinux1_x86_64"
- corrected_wheel_name = f"vllm-{version}-{arch_tag}.whl"
- final_wheel_path = os.path.join(dist_dir, corrected_wheel_name)
-
- print(
- "Docker build context detected, copying precompiled wheel to "
- f"{final_wheel_path}")
- shutil.copy2(wheel_path, final_wheel_path)
-
def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
@@ -676,16 +626,37 @@ package_data = {
]
}
+# If using precompiled, extract and patch package_data (in advance of setup)
+if envs.VLLM_USE_PRECOMPILED:
+ assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
+ wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
+ if wheel_location is not None:
+ wheel_url = wheel_location
+ else:
+ base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
+ wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+ from urllib.request import urlopen
+ try:
+ with urlopen(wheel_url) as resp:
+ if resp.status != 200:
+ wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+ except Exception as e:
+ print(f"[warn] Falling back to nightly wheel: {e}")
+ wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+
+ patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
+ wheel_url)
+ for pkg, files in patch.items():
+ package_data.setdefault(pkg, []).extend(files)
+
if _no_device():
ext_modules = []
-if not ext_modules:
+if not ext_modules or envs.VLLM_USE_PRECOMPILED:
+ # Disable build_ext when using precompiled wheel
cmdclass = {}
else:
- cmdclass = {
- "build_ext":
- repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
- }
+ cmdclass = {"build_ext": cmake_build_ext}
setup(
# static metadata should rather go in pyproject.toml
diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py
index b8d64247f6beb..b394e0035c689 100644
--- a/tests/compile/test_fusion_all_reduce.py
+++ b/tests/compile/test_fusion_all_reduce.py
@@ -7,22 +7,26 @@ import torch
import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
+from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+ GroupShape, QuantFP8)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
-from ..utils import multi_gpu_test
+from ..utils import has_module_attribute, multi_gpu_test
from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module):
- def __init__(self, hidden_size=16, eps=1e-6):
+ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
- def __init__(self, hidden_size=16, eps=1e-6):
+ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
+class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
+
+ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.eps = eps
+ self.norm = RMSNorm(hidden_size, eps)
+ self.quant_fp8 = QuantFP8(static=True,
+ group_shape=GroupShape.PER_TENSOR)
+ self.scale = torch.rand(1, dtype=torch.float32)
+ self.output = torch.empty((token_num, hidden_size),
+ dtype=torch.float32)
+
+ def forward(self, hidden_states, residual):
+ view = hidden_states.reshape(-1, self.hidden_size)
+ all_reduce = tensor_model_parallel_all_reduce(view)
+ norm_output, residual_output = self.norm(all_reduce, residual)
+ torch.ops._C.static_scaled_fp8_quant(self.output,
+ norm_output.contiguous(),
+ self.scale)
+ return self.output, residual_output
+
+ def ops_in_model_after(self):
+ return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
+
+ def ops_in_model_before(self):
+ return [
+ torch.ops.vllm.all_reduce.default,
+ torch.ops._C.static_scaled_fp8_quant.default
+ ]
+
+
+class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
+
+ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.eps = eps
+ self.norm = RMSNorm(hidden_size, eps)
+ self.scale = torch.rand(1, dtype=torch.float32)
+ self.output = torch.empty((token_num, hidden_size),
+ dtype=torch.float32)
+
+ round_up = lambda x, y: (x + y - 1) // y * y
+ rounded_m = round_up(token_num, 128)
+ scale_n = hidden_size // 16
+ rounded_n = round_up(scale_n, 4)
+ self.output_scale = torch.empty((rounded_m, rounded_n // 4),
+ dtype=torch.int32)
+
+ def forward(self, hidden_states, residual):
+ view = hidden_states.reshape(-1, self.hidden_size)
+ all_reduce = tensor_model_parallel_all_reduce(view)
+ norm_output, residual_output = self.norm(all_reduce, residual)
+ norm_output = norm_output.reshape(-1, norm_output.shape[-1])
+ torch.ops._C.scaled_fp4_quant(self.output, norm_output,
+ self.output_scale, self.scale)
+ return self.output, residual_output, self.output_scale
+
+ def ops_in_model_after(self):
+ return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
+
+ def ops_in_model_before(self):
+ return [
+ torch.ops.vllm.all_reduce.default,
+ torch.ops._C.scaled_fp4_quant.default
+ ]
+
+
@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize(
- "test_model",
- [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
+@pytest.mark.parametrize("test_model", [
+ TestAllReduceRMSNormModel,
+ TestAllReduceFusedAddRMSNormModel,
+ TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
+ TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
+])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
-@pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
-@pytest.mark.skipif(not find_spec("flashinfer"),
- reason="flashinfer is not installed")
-@pytest.mark.skipif(not current_platform.is_device_capability(100),
- reason="Only test on SM100")
+@pytest.mark.skipif(
+ not find_spec("flashinfer")
+ or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
+ reason="flashinfer is not found or flashinfer "
+ "is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2
+ if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
+ and not current_platform.has_device_capability(100)):
+ pytest.skip("Skip as nvfp4 is only supported on "
+ "devices with compute capability 10.0 (Blackwell)")
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
@@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
- vllm_config = VllmConfig(
- compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
- custom_ops=["+rms_norm"],
- compile_sizes=[2, 4, 8]))
+ vllm_config = VllmConfig(compilation_config=CompilationConfig(
+ level=CompilationLevel.PIECEWISE,
+ custom_ops=["+rms_norm", "+quant_fp8"]))
vllm_config.compilation_config.pass_config = PassConfig(
- enable_fi_allreduce_fusion=True)
+ enable_fi_allreduce_fusion=True, enable_noop=True)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
@@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
seed=42)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
- backend = TestBackend(all_reduce_fusion_pass)
+ noop_pass = NoOpEliminationPass(vllm_config)
+ func_pass = FixFunctionalizationPass(vllm_config)
- model = test_model_cls(hidden_size)
+ backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
- hidden_states = torch.randn((batch_size * seq_len, hidden_size),
- requires_grad=False)
- residual = torch.randn((batch_size * seq_len, hidden_size),
- requires_grad=False)
+ token_num = batch_size * seq_len
+ model = test_model_cls(hidden_size, token_num)
+
+ hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
+ residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 8fcff5a8c5113..b9e7de4e9fd11 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
+ "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
+ trust_remote_code=True,
+ is_available_online=False),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
trust_remote_code=True),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
+ "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
+ trust_remote_code=True,
+ is_available_online=False),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
diff --git a/tests/utils.py b/tests/utils.py
index f4317e6bdb406..1c1a1cc6014ec 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,6 +4,7 @@
import asyncio
import copy
import functools
+import importlib
import os
import signal
import subprocess
@@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]
+
+
+def has_module_attribute(module_name, attribute_name):
+ """
+ Helper function to check if a module has a specific attribute.
+ """
+ try:
+ module = importlib.import_module(module_name)
+ return hasattr(module, attribute_name)
+ except ImportError:
+ return False
diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py
new file mode 100644
index 0000000000000..8c5a63653db9f
--- /dev/null
+++ b/tests/v1/attention/test_chunked_local_attention.py
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+import torch
+
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
+from vllm.v1.attention.backends.utils import (
+ make_local_attention_virtual_batches)
+
+
+@dataclass
+class LocalAttentionTestData:
+ # Input parameters
+ batch_spec: BatchSpec
+ attn_chunk_size: int
+ block_size: int
+ # Expected return values
+ expected_q_seqlens: list[int]
+ expected_k_seqlens: list[int]
+ expected_local_block_table: list[list[int]]
+
+
+test_data_list = [
+ # Same as example in docstring of make_local_attention_virtual_batches
+ # except block table has 9 columns instead of 10
+ LocalAttentionTestData(
+ batch_spec=BatchSpec(
+ query_lens=[4, 10, 5],
+ seq_lens=[6, 17, 9],
+ ),
+ attn_chunk_size=4,
+ block_size=2,
+ expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
+ expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
+ # 2 pages per local branch
+ # (chunk size 4 // block size 2)
+ expected_local_block_table=[
+ [0, 1], # local-batch 0, (batch 0, starting from k[0])
+ [2, 3], # local-batch 1, (batch 0, starting from k[4])
+ [11, 12], # local-batch 2, (batch 1, starting from k[4])
+ [13, 14], # local-batch 3, (batch 1, starting from k[8])
+ [15, 16], # local-batch 4, (batch 1, starting from k[12])
+ [17, 17], # local-batch 5, (batch 1, starting from k[16])
+ [20, 21], # local-batch 6, (batch 2, starting from k[4])
+ [22, 23], # local-batch 7, (batch 2, starting from k[8])
+ ]),
+ # Case where block indices are not clipped to block table ncols-1
+ # because tokens_in_last_block == attn_chunk_size
+ LocalAttentionTestData(batch_spec=BatchSpec(
+ query_lens=[8],
+ seq_lens=[12],
+ ),
+ attn_chunk_size=4,
+ block_size=2,
+ expected_q_seqlens=[4, 4],
+ expected_k_seqlens=[4, 4],
+ expected_local_block_table=[
+ [2, 3],
+ [4, 5],
+ ]),
+ # Case where all kv_seq positions are involved in attn
+ LocalAttentionTestData(
+ batch_spec=BatchSpec(
+ query_lens=[7],
+ # 10 - 7 = 3 previously computed tokens
+ seq_lens=[10],
+ ),
+ attn_chunk_size=4,
+ block_size=2,
+ expected_q_seqlens=[1, 4, 2],
+ expected_k_seqlens=[4, 4, 2],
+ expected_local_block_table=[
+ [0, 1],
+ [2, 3],
+ [4, 4],
+ ]),
+ # Case where attn_chunk_size > kv_seq_len
+ # so no extra mini virtual batches are created
+ LocalAttentionTestData(
+ batch_spec=BatchSpec(
+ query_lens=[4],
+ seq_lens=[6],
+ ),
+ # Larger than kv_seq_len
+ attn_chunk_size=10,
+ block_size=2,
+ # No change to q_seqlens and k_seqlens
+ expected_q_seqlens=[4],
+ expected_k_seqlens=[6],
+ # In this case, we only need a block-table like:
+ # block_table = [ [0, 1, 2] ] # 1 batch, 3 pages
+ # But we need to pad it to 5 pages per local batch
+ # because currently the pages_per_local_batch
+ # is calculated as (attn_chunk_size // block_size)
+ expected_local_block_table=[
+ [0, 1, 2, 2, 2],
+ ]),
+ # Block size equal to chunk size
+ # Expect single page per batch in local batch table
+ LocalAttentionTestData(
+ batch_spec=BatchSpec(
+ query_lens=[6, 6],
+ seq_lens=[8, 8],
+ ),
+ attn_chunk_size=4,
+ block_size=4,
+ expected_q_seqlens=[2, 4, 2, 4],
+ expected_k_seqlens=[4, 4, 4, 4],
+ # Initial block table = [
+ # [0, 1], < batch 0
+ # [2, 3], < batch 1
+ # ]
+ expected_local_block_table=[
+ [0], # local-batch 0, (batch 0, starting from k[0])
+ [1], # local-batch 1, (batch 0, starting from k[4])
+            [2],  # local-batch 2, (batch 1, starting from k[0])
+            [3],  # local-batch 3, (batch 1, starting from k[4])
+ ]),
+ # Case where query falls in the second attention chunk
+ # k_toks > 0 1 2 3 4
+ # q_toks v _____________
+ # 0 | 1
+ # 1 | 1 1
+ # 2 | 1 1 1
+ # 3 | 1 1 1 1
+ # 4 | 1
+ # where tokens 0,1,2,3 have been pre-computed
+ LocalAttentionTestData(batch_spec=BatchSpec(
+ query_lens=[1],
+ seq_lens=[5],
+ ),
+ attn_chunk_size=4,
+ block_size=2,
+ expected_q_seqlens=[1],
+ expected_k_seqlens=[1],
+ expected_local_block_table=[
+ [2, 2],
+ ]),
+]
+
+
+@pytest.mark.parametrize("test_data", test_data_list)
+def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
+ device = torch.device("cuda:0")
+ batch_spec = test_data.batch_spec
+ attn_chunk_size = test_data.attn_chunk_size
+ block_size = test_data.block_size
+ expected_q_seqlens = test_data.expected_q_seqlens
+ expected_k_seqlens = test_data.expected_k_seqlens
+ expected_local_block_table = test_data.expected_local_block_table
+
+ # Create common attention metadata
+ common_attn_metadata = create_common_attn_metadata(
+ batch_spec,
+ block_size,
+ device,
+ # Use torch.arange instead of torch.randint so we can assert on
+ # block table tensor values. The block table will have shape
+ # (num_batches, cdiv(max_seq_len, block_size)) and the values will be
+ # aranged from 0 to cdiv(max_seq_len, block_size)-1
+ arange_block_indices=True,
+ )
+
+ # Call the function
+ result = make_local_attention_virtual_batches(attn_chunk_size,
+ common_attn_metadata,
+ block_size)
+
+ # Convert to numpy for easier comparison
+ actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
+ actual_k_seqlens = result.seq_lens_cpu.numpy()
+
+ # Check that all query lengths are less than or equal to attn_chunk_size
+ assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
+ # Check that all key lengths are less than or equal to attn_chunk_size
+ assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens)
+ # Check that the total number of query tokens is preserved
+ assert sum(actual_q_seqlens) == sum(batch_spec.query_lens)
+
+ # Verify results
+ np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens)
+ np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens)
+
+ expected_block_table_tensor =\
+ torch.tensor(expected_local_block_table,
+ dtype=torch.int32,
+ device=device)
+
+ print(f"Expected block table:\n{expected_block_table_tensor}")
+ print(f"Actual block table:\n{result.block_table_tensor}")
+
+ torch.testing.assert_close(result.block_table_tensor,
+ expected_block_table_tensor)
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index ae2ab6e6413c0..be6cfce6fba8a 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -40,7 +40,8 @@ def create_common_attn_metadata(
batch_spec: BatchSpec,
block_size: int,
device: torch.device,
- max_block_idx: int = 1000) -> CommonAttentionMetadata:
+ max_block_idx: int = 1000,
+ arange_block_indices: bool = False) -> CommonAttentionMetadata:
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
# Create query start locations
query_start_loc = torch.zeros(batch_spec.batch_size + 1,
@@ -65,19 +66,28 @@ def create_common_attn_metadata(
]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
- # Create block table (random for testing)
+ # Create block table and slot mapping
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
- block_table_tensor = torch.randint(0,
- max_block_idx,
- (batch_spec.batch_size, max_blocks),
- dtype=torch.int32,
- device=device)
-
- # Create slot mapping
- slot_mapping = torch.randint(0,
- max_block_idx, (num_tokens, ),
- dtype=torch.int64,
- device=device)
+ if arange_block_indices:
+ num_blocks = batch_spec.batch_size * max_blocks
+ block_table_tensor = torch.arange(num_blocks,
+ dtype=torch.int32,
+ device=device).view(
+ batch_spec.batch_size,
+ max_blocks)
+ slot_mapping = torch.arange(num_tokens,
+ dtype=torch.int64,
+ device=device).view(num_tokens)
+ else:
+ block_table_tensor = torch.randint(0,
+ max_block_idx,
+ (batch_spec.batch_size, max_blocks),
+ dtype=torch.int32,
+ device=device)
+ slot_mapping = torch.randint(0,
+ max_block_idx, (num_tokens, ),
+ dtype=torch.int64,
+ device=device)
# Calculate max query length
max_query_len = max(batch_spec.query_lens)
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 2423f966acfab..31f25e94c5b4b 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -3,29 +3,34 @@
from __future__ import annotations
import random
-from typing import Any
+from typing import Any, Union
import pytest
import torch
from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
+ if mm_enabled:
+ prompt_types.append("mm")
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+ print(f"Prompt types: {random_prompt_type_choices}")
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
+ prompt: Union[str, list[dict[str, Any]]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
+ elif kind == "mm":
+ placeholders = [{
+ "type": "image_url",
+ "image_url": {
+ "url":
+ f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+ },
+ }]
+ prompt = [
+ *placeholders,
+ {
+ "type": "text",
+ "text": "The meaning of the image is"
+ },
+ ]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@@ -57,7 +77,6 @@ def model_name():
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
- test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
@@ -67,6 +86,7 @@ def test_ngram_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
+ test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -103,23 +123,32 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()
-@pytest.mark.parametrize("model_setup", [
- ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
- "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
- ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
- "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
- pytest.param(
- ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
- "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
- marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-],
- ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+ ["model_setup", "mm_enabled"], [
+ (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+ "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+ (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+ "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+ pytest.param(
+ ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+ "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+ False,
+ marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+ pytest.param(
+ ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+ "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+ True,
+ marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+ ],
+ ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
- test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
+ mm_enabled: bool,
):
+ # Generate test prompts inside the function instead of using fixture
+ test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index cb99fe8310e73..6ae50245ed3a8 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -37,6 +37,8 @@ logger = init_logger(__name__)
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
+STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
+STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern:
@@ -394,7 +396,7 @@ if flashinfer_comm is not None:
# Max size of the input tensor per world size
# to use flashinfer fused allreduce
_FI_MAX_SIZES = {
- 2: MiB, # 1MB
+ 2: 64 * MiB, # 64MB
4: MiB, # 1MB
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
@@ -414,9 +416,13 @@ if flashinfer_comm is not None:
trigger_completion_at_end: bool,
fp32_acc: bool,
max_token_num: int,
+ pattern_code: int,
+ fuse_rms_quant: bool,
norm_out: Optional[torch.Tensor] = None,
+ quant_out: Optional[torch.Tensor] = None,
+ scale_out: Optional[torch.Tensor] = None,
+ scale_factor: Optional[torch.Tensor] = None,
) -> None:
-
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
@@ -425,7 +431,6 @@ if flashinfer_comm is not None:
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
-
if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer"
@@ -455,37 +460,65 @@ if flashinfer_comm is not None:
use_oneshot=True,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
- pattern_code=flashinfer_comm.AllReduceFusionPattern.
- kARResidualRMSNorm,
+ pattern_code=pattern_code,
allreduce_out=None,
- quant_out=None,
- scale_out=None,
- layout_code=None,
- scale_factor=None,
+ quant_out=quant_out,
+ scale_out=scale_out,
+ # in vllm we only support swizzled layout
+ layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
+ scale_factor=scale_factor,
)
else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
- if norm_out is None:
- torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
- rms_gamma, rms_eps)
+ if (scale_factor is not None and scale_out is None
+ and fuse_rms_quant):
+ # Do fused rms norm static fp8 quant fused op
+ if norm_out is None:
+ torch.ops._C.fused_add_rms_norm_static_fp8_quant(
+ quant_out, allreduce_out, residual, rms_gamma,
+ scale_factor, rms_eps)
+ else:
+ torch.ops._C.rms_norm_static_fp8_quant(
+ quant_out, allreduce_out, rms_gamma, scale_factor,
+ rms_eps)
else:
- torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
- rms_eps)
- allreduce_in.copy_(allreduce_out)
+ if norm_out is None:
+ torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
+ rms_gamma, rms_eps)
+ norm_out = allreduce_out
+ else:
+ torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
+ rms_eps)
+ if scale_factor is not None:
+ if scale_out is not None:
+ torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
+ scale_out, scale_factor)
+ else:
+ torch.ops._C.static_scaled_fp8_quant(
+ quant_out, norm_out, scale_factor)
+ if scale_factor is None or norm_out is not None:
+                # we need to return allreduce output
+ # in cases of non quant fused AR + RMS norm
+ # and fused AR + RMS norm + quant without fused add
+ allreduce_in.copy_(allreduce_out)
def call_trtllm_fused_allreduce_norm_fake(
- allreduce_in: torch.Tensor,
- residual: torch.Tensor,
- rms_gamma: torch.Tensor,
- rms_eps: float,
- world_rank: int,
- world_size: int,
- launch_with_pdl: bool,
- trigger_completion_at_end: bool,
- fp32_acc: bool,
- max_token_num: int,
- norm_out: Optional[torch.Tensor] = None,
- ) -> None:
+ allreduce_in: torch.Tensor,
+ residual: torch.Tensor,
+ rms_gamma: torch.Tensor,
+ rms_eps: float,
+ world_rank: int,
+ world_size: int,
+ launch_with_pdl: bool,
+ trigger_completion_at_end: bool,
+ fp32_acc: bool,
+ max_token_num: int,
+ pattern_code: int,
+ fuse_rms_quant: bool,
+ norm_out: Optional[torch.Tensor] = None,
+ quant_out: Optional[torch.Tensor] = None,
+ scale_out: Optional[torch.Tensor] = None,
+ scale_factor: Optional[torch.Tensor] = None) -> None:
pass
direct_register_custom_op(
@@ -495,6 +528,8 @@ if flashinfer_comm is not None:
"allreduce_in",
"residual",
"norm_out",
+ "quant_out",
+ "scale_out",
],
fake_impl=call_trtllm_fused_allreduce_norm_fake,
dispatch_key=current_platform.dispatch_key,
@@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams:
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
+ fuse_rms_quant: bool = False,
):
self.rank = rank
self.world_size = world_size
@@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True
self.use_oneshot = False
self.max_token_num = max_token_num
+ self.fuse_rms_quant = fuse_rms_quant
def get_trtllm_fused_allreduce_kwargs(self):
return {
@@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams:
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
+ "fuse_rms_quant": self.fuse_rms_quant,
}
-class AllReduceRMSNORMPattern(BasePattern):
+class AllReduceRMSNormPattern(BasePattern):
+ """
+ This pattern replaces the allreduce + rms norm (without residual)
+ with fused flashinfer implementation.
+ Applies to allreduce + rmsnorm before attn in the first Transformer block.
+ """
def __init__(
self,
@@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern):
def pattern(input: torch.Tensor, rms_result: torch.Tensor,
weight: torch.Tensor):
- all_reduce_output = tensor_model_parallel_all_reduce(input)
+ allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_OP,
result=rms_result,
- input=all_reduce_output,
+ input=allreduce_output,
weight=weight,
epsilon=self.epsilon,
)
- return rms[1], all_reduce_output
+ # rms_result, allreduce_output
+ return rms[1], allreduce_output
def replacement(input: torch.Tensor, rms_result: torch.Tensor,
weight: torch.Tensor):
residual = torch.zeros_like(input)
allreduce = auto_functionalized(
- torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+ flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=rms_result,
+ quant_out=None,
+ scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
-
+ # rms_result, allreduce_in
return allreduce[3], allreduce[1]
pm.register_replacement(pattern, replacement, self.get_inputs(),
@@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern):
class AllReduceFusedAddRMSNormPattern(BasePattern):
+ """
+ This pattern replaces the allreduce + rms norm (with residual)
+ with fused flashinfer implementation.
+ Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
+ """
def __init__(
self,
@@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
def pattern(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor):
- all_reduce_output = tensor_model_parallel_all_reduce(input)
+ allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_ADD_OP,
- input=all_reduce_output,
+ input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
+ # input, residual
return rms[1], rms[2]
def replacement(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor):
allreduce = auto_functionalized(
- torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+ flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
+ norm_out=None,
+ quant_out=None,
+ scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
- norm_out=None,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
+ # allreduce_in, residual
return allreduce[1], allreduce[2]
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
+class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
+ """
+ This pattern replaces the allreduce + rms norm (without residual)
+ + static fp8 quant with fused flashinfer implementation.
+ Applies to allreduce + rmsnorm + quant before attn
+ in the first Transformer block.
+ """
+
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+ allreduce_params: FlashInferFusedAllReduceParams):
+ super().__init__(dtype, device)
+ self.epsilon = epsilon
+ self.allreduce_params = allreduce_params
+ self.quant_dtype = torch.float8_e4m3fn
+
+ def register(self, pm_pass: PatternMatcherPass):
+
+ def get_inputs():
+ input = torch.zeros([1, 8, 4],
+ device=self.device,
+ dtype=self.dtype)
+ rmsnorm_result = torch.empty([1, 8, 4],
+ device=self.device,
+ dtype=self.dtype)
+ quant_result = torch.empty([1, 8, 4],
+ device=self.device,
+ dtype=self.quant_dtype)
+ weight = torch.empty([4], device=self.device, dtype=self.dtype)
+ scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
+ return [input, rmsnorm_result, quant_result, weight, scale]
+
+ def pattern(
+ input: torch.Tensor,
+ rmsnorm_result: torch.Tensor,
+ quant_result: torch.Tensor,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ ):
+ all_reduce = tensor_model_parallel_all_reduce(input)
+ rmsnorm_out_tuple = auto_functionalized(RMS_OP,
+ result=rmsnorm_result,
+ input=all_reduce,
+ weight=weight,
+ epsilon=self.epsilon)
+
+ quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
+ result=quant_result,
+ input=rmsnorm_out_tuple[1],
+ scale=scale)
+
+ # quant_out, allreduce_output
+ return quant_out_tuple[1], all_reduce
+
+ def replacement(input: torch.Tensor, result_rms: torch.Tensor,
+ quant_result: torch.Tensor, weight: torch.Tensor,
+ scale: torch.Tensor):
+ residual = torch.zeros_like(input)
+ allreduce = auto_functionalized(
+ flashinfer_trtllm_fused_allreduce_norm,
+ allreduce_in=input,
+ residual=residual,
+ norm_out=result_rms,
+ quant_out=quant_result,
+ scale_out=None,
+ rms_gamma=weight,
+ rms_eps=self.epsilon,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
+ scale_factor=scale,
+ **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+ )
+
+ # quant_out, allreduce_output
+ return allreduce[4], allreduce[1]
+
+ pm.register_replacement(pattern, replacement, get_inputs(),
+ pm.fwd_only, pm_pass)
+
+
+class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
+ """
+ This pattern replaces the allreduce + rms norm (with residual)
+ + static fp8 quant with fused flashinfer implementation.
+ Applies to o_proj + rmsnorm after attn + quant and
+ mlp + rmsnorm + quant before attn.
+ """
+
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+ allreduce_params: FlashInferFusedAllReduceParams):
+ super().__init__(dtype, device)
+ self.epsilon = epsilon
+ self.allreduce_params = allreduce_params
+ self.quant_dtype = torch.float8_e4m3fn
+
+ def register(self, pm_pass: PatternMatcherPass):
+
+ def get_inputs():
+ input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+
+ residual = torch.empty([4, 4],
+ device=self.device,
+ dtype=self.dtype)
+ weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+ quant_result = torch.empty([4, 4],
+ device=self.device,
+ dtype=self.quant_dtype)
+ scale = torch.empty([1, 1],
+ device=self.device,
+ dtype=torch.float32)
+
+ return [
+ quant_result,
+ residual,
+ input,
+ weight,
+ scale,
+ ]
+
+ def pattern(
+ quant_result: torch.Tensor,
+ residual: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ ):
+ allreduce_output = tensor_model_parallel_all_reduce(input)
+
+ fused_add_rmsnorm_out_tuple = \
+ auto_functionalized(
+ RMS_ADD_OP,
+ input=allreduce_output,
+ residual=residual,
+ weight=weight,
+ epsilon=self.epsilon)
+ quant_out_tuple = auto_functionalized(
+ STATIC_FP8_QUANT_OP,
+ result=quant_result,
+ input=fused_add_rmsnorm_out_tuple[1],
+ scale=scale)
+
+ # quant_out, allreduce_output
+ return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
+
+ def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
+ input: torch.Tensor, weight: torch.Tensor,
+ scale: torch.Tensor):
+ allreduce = auto_functionalized(
+ flashinfer_trtllm_fused_allreduce_norm,
+ allreduce_in=input,
+ residual=residual,
+ norm_out=None,
+ quant_out=quant_result,
+ scale_out=None,
+ rms_gamma=weight,
+ rms_eps=self.epsilon,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
+ scale_factor=scale,
+ **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+ )
+            # quant_out, rms_norm_residual
+ return allreduce[4], allreduce[2]
+
+ pm.register_replacement(pattern, replacement, get_inputs(),
+ pm.fwd_only, pm_pass)
+
+
+class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
+ """
+ This pattern replaces the allreduce + rms norm (without residual)
+ + static nvfp4 quant with fused flashinfer implementation.
+ Applies to allreduce + rmsnorm + quant before attn
+ in the first Transformer block.
+ """
+
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+ allreduce_params: FlashInferFusedAllReduceParams):
+ super().__init__(dtype, device)
+ self.epsilon = epsilon
+ self.allreduce_params = allreduce_params
+
+ def register(self, pm_pass: PatternMatcherPass):
+
+ def get_inputs():
+ input = torch.empty([1, 16, 16],
+ device=self.device,
+ dtype=self.dtype)
+
+ rmsnorm_result = torch.empty([1, 16, 16],
+ device=self.device,
+ dtype=self.dtype)
+ quant_result = torch.empty((16, 8),
+ device=self.device,
+ dtype=torch.uint8)
+ input_global_scale = torch.empty([1, 1],
+ device=self.device,
+ dtype=torch.float32)
+ weight = torch.empty([16], device=self.device, dtype=self.dtype)
+ output_scale = torch.empty([128, 4],
+ device=self.device,
+ dtype=torch.int32)
+
+ return [
+ input, rmsnorm_result, quant_result, weight,
+ input_global_scale, output_scale
+ ]
+
+ def pattern(
+ input: torch.Tensor,
+ rmsnorm_result: torch.Tensor,
+ quant_result: torch.Tensor,
+ weight: torch.Tensor,
+ input_global_scale: torch.Tensor,
+ output_scale: torch.Tensor,
+ ):
+ all_reduce = tensor_model_parallel_all_reduce(input)
+ rmsnorm_out_tuple = auto_functionalized(RMS_OP,
+ result=rmsnorm_result,
+ input=all_reduce,
+ weight=weight,
+ epsilon=self.epsilon)
+
+ quant_out_tuple = auto_functionalized(
+ STATIC_FP4_QUANT_OP,
+ output=quant_result,
+ input=rmsnorm_out_tuple[1],
+ output_scale=output_scale,
+ input_scale=input_global_scale)
+
+ # quant_out, allreduce_output, output_scale
+ return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
+
+ def replacement(input: torch.Tensor, result_rms: torch.Tensor,
+ quant_result: torch.Tensor, weight: torch.Tensor,
+ input_global_scale: torch.Tensor,
+ output_scale: torch.Tensor):
+ residual = torch.zeros_like(input)
+ allreduce = auto_functionalized(
+ flashinfer_trtllm_fused_allreduce_norm,
+ allreduce_in=input,
+ residual=residual,
+ norm_out=result_rms,
+ quant_out=quant_result,
+ scale_out=output_scale,
+ rms_gamma=weight,
+ rms_eps=self.epsilon,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
+ scale_factor=input_global_scale,
+ **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+ )
+
+ # quant_out, allreduce_output, output_scale
+ return allreduce[4], allreduce[1], allreduce[5]
+
+ pm.register_replacement(pattern, replacement, get_inputs(),
+ pm.fwd_only, pm_pass)
+
+
+class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
+ """
+ This pattern replaces the allreduce + rms norm (with residual)
+ + static nvfp4 quant with fused flashinfer implementation.
+ Applies to o_proj + rmsnorm after attn + quant and
+ mlp + rmsnorm + quant before attn.
+ """
+
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+ allreduce_params: FlashInferFusedAllReduceParams):
+ super().__init__(dtype, device)
+ self.epsilon = epsilon
+ self.allreduce_params = allreduce_params
+
+ def register(self, pm_pass: PatternMatcherPass):
+
+ def get_inputs():
+ input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
+
+ residual = torch.empty([16, 16],
+ device=self.device,
+ dtype=self.dtype)
+ weight = torch.empty([16, 16],
+ device=self.device,
+ dtype=self.dtype)
+ quant_result = torch.empty((16, 8),
+ device=self.device,
+ dtype=torch.uint8)
+ input_global_scale = torch.empty([1, 1],
+ device=self.device,
+ dtype=torch.float32)
+ output_scale = torch.empty([128, 4],
+ device=self.device,
+ dtype=torch.int32)
+
+ return [
+ quant_result,
+ residual,
+ input,
+ output_scale,
+ weight,
+ input_global_scale,
+ ]
+
+ def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
+ input: torch.Tensor, output_scale: torch.Tensor,
+ weight: torch.Tensor, input_global_scale: torch.Tensor):
+ allreduce_output = tensor_model_parallel_all_reduce(input)
+
+ fused_add_rmsnorm_out_tuple = \
+ auto_functionalized(
+ RMS_ADD_OP,
+ input=allreduce_output,
+ residual=residual,
+ weight=weight,
+ epsilon=self.epsilon)
+ quant_out_tuple = auto_functionalized(
+ STATIC_FP4_QUANT_OP,
+ output=quant_result,
+ input=fused_add_rmsnorm_out_tuple[1],
+ output_scale=output_scale,
+ input_scale=input_global_scale)
+
+ # quant_out, allreduce_output, output_scale
+ return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
+ 2], quant_out_tuple[2]
+
+ def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
+ input: torch.Tensor, output_scale: torch.Tensor,
+ weight: torch.Tensor,
+ input_global_scale: torch.Tensor):
+ allreduce = auto_functionalized(
+ flashinfer_trtllm_fused_allreduce_norm,
+ allreduce_in=input,
+ residual=residual,
+ norm_out=None,
+ quant_out=quant_result,
+ scale_out=output_scale,
+ rms_gamma=weight,
+ rms_eps=self.epsilon,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
+ scale_factor=input_global_scale,
+ **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+ )
+ # quant_out, rms_norm_residual, output_scale
+ return allreduce[4], allreduce[2], allreduce[5]
+
+ pm.register_replacement(pattern, replacement, get_inputs(),
+ pm.fwd_only, pm_pass)
+
+
class AllReduceFusionPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
@@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass):
self.tp_size,
)
return
-
+ max_num_token = min(
+ _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
+ (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
+ config.compilation_config.pass_config.
+ fi_allreduce_fusion_max_token_num)
self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
- max_token_num=config.compilation_config.pass_config.
- fi_allreduce_fusion_max_token_num,
+ max_token_num=max_num_token,
hidden_dim=self.hidden_dim,
group=self.group,
use_fp32_lamport=use_fp32_lamport,
@@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass):
rank=rank,
world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport,
- max_token_num=config.compilation_config.pass_config.
- fi_allreduce_fusion_max_token_num,
- )
+ max_token_num=max_num_token,
+ # fuse rms norm static fp8 quant fused op
+ # in fallback path, when we don't use flashinfer
+ fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
for epsilon in [1e-5, 1e-6]:
- AllReduceRMSNORMPattern(
+ AllReduceFusedRMSNormStaticQuantFP8Pattern(
+ epsilon,
+ self.model_dtype,
+ self.device,
+ self.allreduce_params,
+ ).register(self.patterns)
+ AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
+ epsilon,
+ self.model_dtype,
+ self.device,
+ self.allreduce_params,
+ ).register(self.patterns)
+ if current_platform.has_device_capability(100):
+ AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
+ epsilon,
+ self.model_dtype,
+ self.device,
+ self.allreduce_params,
+ ).register(self.patterns)
+ AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
+ epsilon,
+ self.model_dtype,
+ self.device,
+ self.allreduce_params,
+ ).register(self.patterns)
+ AllReduceRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
@@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass):
self.allreduce_params,
).register(self.patterns)
+ # WARNING: This is a hack to clear the pattern matcher cache
+ # and allow multiple values of epsilon.
+ torch._inductor.pattern_matcher._seen_patterns.clear()
+
self.disabled = False
def __call__(self, graph: fx.Graph):
@@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass):
if self.disabled:
return
if flashinfer_comm is not None:
- flashinfer_comm.trtllm_destroy_ipc_workspace(
+ flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
self.ipc_handles, self.group)
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index baf98306ad241..cb629ea284fa1 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -108,7 +108,7 @@ def support_torch_compile(
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- - if it is a single integer (can be negative), the corresponding dimension
+ - if it is a single integer (can be negative), the corresponding dimension
of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
diff --git a/vllm/config.py b/vllm/config.py
index dd6ff26c186c7..6908c5a121dae 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4062,7 +4062,7 @@ class PassConfig:
"""Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False
"""Whether to enable flashinfer allreduce fusion."""
- fi_allreduce_fusion_max_token_num: int = 1024
+ fi_allreduce_fusion_max_token_num: int = 16384
"""Max number of tokens to used in flashinfer allreduce fusion."""
# TODO(luka) better pass enabling system.
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 88c8aa929b78d..099e456aa486f 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
+from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
__all__ = [
@@ -40,4 +41,5 @@ __all__ = [
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
+ "Step3ToolParser",
]
diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
new file mode 100644
index 0000000000000..a20d18eb52544
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
@@ -0,0 +1,296 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import contextlib
+import json
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module(["step3"])
+class Step3ToolParser(ToolParser):
+ """
+ Tool parser for a model that uses a specific XML-like format for tool calls.
+ This version uses a robust, stateful, cursor-based streaming parser and
+ consolidates tool arguments into a single message.
+ """
+
+ TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
+ TOOL_CALLS_END = "<|tool_calls_end|>"
+ TOOL_CALL_BEGIN = "<|tool_call_begin|>"
+ TOOL_CALL_END = "<|tool_call_end|>"
+ TOOL_SEP = "<|tool_sep|>"
+ SPECIAL_TOKENS = [
+ TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
+ ]
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+ self.position = 0
+ # Explicit state flags for robust streaming
+ self.tool_block_started = False
+ self.tool_block_finished = False
+
+ def adjust_request(
+ self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+ if request.tools and request.tool_choice != 'none':
+ request.skip_special_tokens = False
+ return request
+
+ @staticmethod
+ def _parse_steptml_invoke(
+ action_text: str
+ ) -> tuple[Optional[str], Optional[dict[str, str]]]:
+        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
+                                    action_text)
+ if not func_name_match:
+ return None, None
+ func_name = func_name_match.group(1)
+
+ params: dict[str, str] = {}
+ param_matches = re.findall(
+            r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
+ action_text)
+ for name, value in param_matches:
+ params[name] = value.strip()
+ return func_name, params
+
+ def _cast_arguments(
+ self,
+ func_name: str,
+ params: dict[str, Any],
+ request: ChatCompletionRequest,
+ ) -> dict[str, Any]:
+ for tool in request.tools or []:
+ if tool.function.name == func_name:
+ schema = tool.function.parameters or {}
+ properties = schema.get("properties", {})
+ for key, value in params.items():
+ if not isinstance(value, str):
+ continue
+ prop = properties.get(key, {})
+ typ = prop.get("type")
+ if typ == "string":
+ params[key] = value.strip()
+ elif typ == "integer":
+ with contextlib.suppress(ValueError):
+ params[key] = int(value)
+ elif typ == "number":
+ with contextlib.suppress(ValueError):
+ params[key] = float(value)
+ elif typ == "boolean":
+ lower_val = value.lower()
+ params[key] = lower_val == "true" if lower_val in (
+ "true", "false") else value
+ elif typ == "null":
+ params[key] = None if value.lower(
+ ) == "null" else value
+ break
+ return params
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ # The main loop processes the stream from the last known position.
+ while True:
+ if self.position >= len(current_text):
+ return None # We've processed the entire stream.
+
+ unprocessed_text = current_text[self.position:]
+
+ # STATE: After all tools are done, all subsequent text is content.
+ if self.tool_block_finished:
+ self.position = len(current_text)
+ return DeltaMessage(content=unprocessed_text)
+
+ # STATE: Before the tool block has started.
+ if not self.tool_block_started:
+ if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
+ self.position += len(self.TOOL_CALLS_BEGIN)
+ self.tool_block_started = True
+ continue # Token consumed, re-loop.
+
+ start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
+ if start_pos == -1:
+ if self.TOOL_CALLS_BEGIN.startswith(
+ unprocessed_text.strip()) and unprocessed_text:
+ return None # It's a prefix, wait.
+ self.position = len(current_text)
+ return DeltaMessage(content=unprocessed_text)
+ else:
+ content = unprocessed_text[:start_pos]
+ self.position += len(content)
+ return DeltaMessage(content=content)
+
+ # STATE: Inside the main tool block.
+ offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
+ unprocessed_text = unprocessed_text.lstrip()
+ self.position += offset
+
+ if unprocessed_text.startswith(self.TOOL_CALLS_END):
+ self.position += len(self.TOOL_CALLS_END)
+ self.tool_block_finished = True
+ self.current_tool_id = -1
+ continue
+
+ # Check if we are between tool calls.
+ tool_finished = (
+ self.current_tool_id != -1 and
+ self.prev_tool_call_arr[self.current_tool_id].get("finished"))
+ if self.current_tool_id == -1 or tool_finished:
+ if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
+ self.position += len(self.TOOL_CALL_BEGIN)
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ else:
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ self.prev_tool_call_arr[
+ self.current_tool_id]["finished"] = False
+ continue
+
+ if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
+ return None
+
+ # STATE: Parsing an active tool call.
+ if self.current_tool_id != -1 and not self.prev_tool_call_arr[
+ self.current_tool_id].get("finished", False):
+ end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
+ if end_tool_pos == -1:
+ tool_body = unprocessed_text
+ else:
+ tool_body = unprocessed_text[:end_tool_pos]
+
+ if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
+ tool_body):
+ return None
+
+ function_name, arguments = self._parse_steptml_invoke(
+ tool_body)
+ if not function_name:
+ return None
+
+ tool_call_arr = {
+ "name": function_name,
+ "parameters": arguments or {}
+ }
+
+ # Send the function name as soon as it's parsed.
+ if not self.current_tool_name_sent:
+ self.current_tool_name_sent = True
+ self.prev_tool_call_arr[self.current_tool_id].update(
+ tool_call_arr)
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name))
+ ])
+
+ # Update our internal state with the latest parsed arguments.
+ self.prev_tool_call_arr[
+ self.current_tool_id].update( # noqa: E501
+ tool_call_arr)
+
+ # Only send arguments when the tool call is complete.
+ if end_tool_pos != -1:
+ self.position += end_tool_pos + len(self.TOOL_CALL_END)
+ self.prev_tool_call_arr[
+ self.current_tool_id]["finished"] = True
+
+ final_args = self._cast_arguments(
+ function_name,
+ tool_call_arr.get("parameters", {}), # type: ignore
+ request)
+ if final_args:
+ final_args_json = json.dumps(final_args,
+ ensure_ascii=False)
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=final_args_json))
+ ])
+
+ # If tool is not finished, return None to wait for more tokens.
+ return None
+
+ return None
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ if self.TOOL_CALLS_BEGIN not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
+ if self.TOOL_CALLS_END not in rest:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
+ content = (pre_text + post_text).strip()
+
+ tool_calls: list[ToolCall] = []
+ call_parts = tool_block.split(self.TOOL_CALL_BEGIN)
+
+ for part in call_parts:
+ if not part or self.TOOL_CALL_END not in part:
+ continue
+
+ call_content = part.split(self.TOOL_CALL_END, 1)[0]
+ if self.TOOL_SEP not in call_content:
+ continue
+
+ type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
+ if type_part.strip() != "function":
+ continue
+
+ function_name, params_dict = self._parse_steptml_invoke(
+ invoke_part)
+
+ if function_name and params_dict is not None:
+ params_dict = self._cast_arguments(function_name, params_dict,
+ request)
+ params_str = json.dumps(params_dict, ensure_ascii=False)
+ tool_calls.append(
+ ToolCall(function=FunctionCall(name=function_name,
+ arguments=params_str)))
+ if tool_calls:
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None)
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 911f0036c2dd6..2e026d582a6de 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
+ ("fused_qkv_a_proj", "q_a_proj", 0),
+ ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
@@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
+
+ # QKV fusion is optional, fall back to normal
+ # weight loading if it's not enabled
+ if ((param_name == "fused_qkv_a_proj")
+ and name not in params_dict):
+ continue
+
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 470e701d98013..60098209c39ac 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
+ self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index 222ab5dfaee4a..ece490ff2f2a8 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.inputs import NestedTensors
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
logger = init_logger(__name__)
@@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
- input_embeds = self.embed_tokens(input_ids)
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = self.fc(
- torch.cat((input_embeds, hidden_states), dim=-1))
+ torch.cat((inputs_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
hidden_states, residual = layer(
@@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
- return self.model(input_ids, positions, hidden_states)
+ return self.model(input_ids, positions, hidden_states, inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
@@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
+
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ multimodal_embeddings,
+ self.config.image_token_index,
+ )
+
+ return inputs_embeds
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index c7690604c1d09..a4933b77e3a53 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
+from typing import Optional
import torch
import torch.nn as nn
@@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
+ if inputs_embeds is not None:
+ raise NotImplementedError(
+ f"{type(self).__name__} does not support multimodal inputs yet."
+ )
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 7fc9fe2ebb6f6..71275f0d58579 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
+ if inputs_embeds is not None:
+ raise NotImplementedError(
+ f"{type(self).__name__} does not support multimodal inputs yet."
+ )
return self.model(input_ids, positions, hidden_states)
def compute_logits(
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 51831a770347a..848c04b9b32f7 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
+ "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
@@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
+ "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
new file mode 100644
index 0000000000000..47d2af5c2a140
--- /dev/null
+++ b/vllm/model_executor/models/step3_text.py
@@ -0,0 +1,521 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Step3 text model."""
+from collections.abc import Iterable
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import (get_pp_group,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers)
+
+logger = init_logger(__name__)
+
+
+class FusedMoEBlock(nn.Module):
+
+ def __init__(self,
+ config: ModelConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+
+ if self.tp_size > config.moe_num_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than "
+ f"the number of experts {config.moe_num_experts}.")
+
+ self.experts = FusedMoE(num_experts=config.moe_num_experts,
+ top_k=config.moe_top_k,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_expert_weight,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts")
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.moe_num_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate")
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ orig_shape = hidden_states.shape
+ hidden_dim = hidden_states.shape[-1]
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ router_logits, _ = self.gate(hidden_states)
+
+ final_hidden_states = self.experts(hidden_states=hidden_states,
+ router_logits=router_logits)
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states)
+
+ return final_hidden_states.view(orig_shape)
+
+
+class Step3TextMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj")
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+ self.hidden_size = hidden_size
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_up, _ = self.gate_up_proj(hidden_states)
+ intermediate_act = self.act_fn(gate_up)
+ output, _ = self.down_proj(intermediate_act)
+ return output
+
+
+class Step3TextAttention(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ norm_eps: float,
+ rope_theta: int,
+ share_q_dim: Optional[int] = None,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embedding: int = 8192,
+ head_dim: int = 256,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+
+ if num_kv_heads != 1:
+ raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
+ f"but got {num_kv_heads}.")
+ self.num_kv_heads = num_kv_heads
+
+ self.head_dim = head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.q_size = share_q_dim if share_q_dim else self.head_dim
+
+ self.qkv_proj = ReplicatedLinear(
+ hidden_size,
+ self.q_size + self.kv_size * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+ self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
+ self.wq = ColumnParallelLinear(
+ self.q_size,
+ self.head_dim * self.total_num_heads,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.wq",
+ )
+ self.rotary_emb = get_rope(self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embedding,
+ base=rope_theta,
+ rope_scaling=rope_scaling)
+ scaling = self.head_dim**-0.5
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ scaling,
+ self.num_kv_heads,
+ cache_config=cache_config,
+ prefix=f"{prefix}.attn")
+
+ def forward(self, positions: torch.Tensor,
+ hidden_states: torch.Tensor) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q = self.inter_norm(q)
+ q = self.wq(q)[0]
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ residual, _ = self.o_proj(attn_output)
+ return residual
+
+
+class Step3TextDecoderLayer(nn.Module):
+
+ def __init__(self,
+ config: ModelConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "") -> None:
+ super().__init__()
+ config = config.hf_config
+ self.hidden_size = config.hidden_size
+ rope_scaling = getattr(config, "rope_scaling", None)
+
+ self.self_attn = Step3TextAttention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=1,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ norm_eps=config.rms_norm_eps,
+ max_position_embedding=config.max_position_embedding,
+ head_dim=config.head_dim,
+ share_q_dim=config.share_q_dim,
+ rope_theta=config.rope_theta,
+ rope_scaling=rope_scaling,
+ prefix=f"{prefix}.self_attn")
+
+ layer_idx = int(prefix.split("layers.")[1].split(".")[0])
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
+ if moe_layers_enum is not None:
+ moe_layers_idx = [
+ int(i) for i in moe_layers_enum.strip().split(',')
+ ]
+ else:
+ # Default to 1dense.
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
+
+ if layer_idx in moe_layers_idx:
+ self.moe = FusedMoEBlock(config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.moe")
+ self.share_expert = Step3TextMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.share_expert_dim,
+ hidden_act="silu",
+ quant_config=quant_config,
+ prefix=f"{prefix}.share_expert")
+ self.use_moe = True
+ else:
+ self.mlp = Step3TextMLP(hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act="silu",
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+ self.use_moe = False
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self, positions: torch.Tensor, hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+
+ if self.use_moe:
+ share_output = self.share_expert(hidden_states)
+ moe_output = self.moe(hidden_states)
+ hidden_states = share_output + moe_output
+ else:
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Step3TextModel(nn.Module):
+
+ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ self.vocab_size = config.vocab_size
+ self.config = config
+
+ if get_pp_group().is_first_rank or (config.tie_word_embeddings
+ and get_pp_group().is_last_rank):
+ self.embed_tokens = VocabParallelEmbedding(
+ self.vocab_size,
+ config.hidden_size,
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Step3TextDecoderLayer(config=vllm_config.
+ model_config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix),
+ prefix=f"{prefix}.layers",
+ )
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(["hidden_states"],
+ config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states, residual = layer(positions, hidden_states, residual)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual,
+ })
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class Step3TextForCausalLM(nn.Module, SupportsPP):
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ lora_config = vllm_config.lora_config
+ self.config = config
+ self.vllm_config = vllm_config
+
+ self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
+
+ if get_pp_group().is_last_rank:
+ self.unpadded_vocab_size = config.vocab_size
+ if lora_config:
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE
+ if not lora_config else lora_config.lora_vocab_padding_size,
+ )
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+ config.vocab_size)
+ self.sampler = get_sampler()
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None):
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: Optional[torch.Tensor],
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ qkv_params_mapping = [
+ # (param_name, shard_name, relative_start_idx, relative_end_idx)
+ (".qkv_proj", ".q_proj", 0, self.config.share_q_dim /
+ (self.config.share_q_dim + self.config.head_dim * 2)),
+ (".qkv_proj", ".k_proj", self.config.share_q_dim /
+ (self.config.share_q_dim + self.config.head_dim * 2),
+ (self.config.share_q_dim + self.config.head_dim) /
+ (self.config.share_q_dim + self.config.head_dim * 2)),
+ (".qkv_proj", ".v_proj",
+ (self.config.share_q_dim + self.config.head_dim) /
+ (self.config.share_q_dim + self.config.head_dim * 2),
+ (self.config.share_q_dim + self.config.head_dim * 2) /
+ (self.config.share_q_dim + self.config.head_dim * 2)),
+ ]
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ expert_params_mapping = [
+ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
+ (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
+ (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
+ ]
+
+ disable_moe_stacked_params = [
+ data[1] for data in expert_params_mapping
+ ]
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ if any(disable_moe_stacked_param in name
+ for disable_moe_stacked_param in
+ disable_moe_stacked_params):
+ continue
+ name = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ loaded_params.add(name)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ # Skip loading extra bias for GPTQ models.
+ if ((name.endswith(".bias") or name.endswith("_bias"))
+ and name not in params_dict):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ for expert_id in range(loaded_weight.shape[0]):
+ loaded_weight_expert = loaded_weight[expert_id]
+ weight_loader(param,
+ loaded_weight_expert,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ loaded_params.add(name)
+ break
+ else:
+ for (param_name, weight_name, start_idx,
+ end_idx) in qkv_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ dim = param.shape[param.output_dim]
+ begin_idx = int(start_idx * dim)
+ end_idx = int(end_idx * dim)
+ param_slice = param.narrow(param.output_dim, begin_idx,
+ end_idx - begin_idx)
+ param_slice.copy_(loaded_weight)
+ loaded_params.add(name)
+ break
+ else:
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
new file mode 100644
index 0000000000000..363c12a4bf2b8
--- /dev/null
+++ b/vllm/model_executor/models/step3_vl.py
@@ -0,0 +1,1052 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from functools import cached_property
+from itertools import product
+from math import ceil, sqrt
+from typing import Any, Literal, Optional, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from transformers import BatchFeature, PretrainedConfig, TensorType
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs, NestedTensors)
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import Step3VisionEncoderConfig
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+
+
+class Step3VLImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values: torch.Tensor
+ patch_pixel_values: Optional[torch.Tensor]
+ num_patches: list[int]
+
+
+class Step3VLImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ image_embeds: torch.Tensor
+
+
+Step3VLImageInputs = Union[Step3VLImagePixelInputs,
+ Step3VLImageEmbeddingInputs]
+
+ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
+
+MAX_IMAGE_SIZE: int = 3024
+
+
+class Step3VisionProcessor:
+
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+ patch_size = patch_size if patch_size is not None else size
+
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ transforms.Resize(
+ (size, size),
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
+ == "bicubic" else InterpolationMode.BILINEAR,
+ antialias=True),
+ ])
+
+ self.patch_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ transforms.Resize(
+ (patch_size, patch_size),
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
+ == "bicubic" else InterpolationMode.BILINEAR,
+ antialias=True),
+ ]) if patch_size is not None else None
+
+ def __call__(self, image, is_patch=False):
+ if is_patch:
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
+ else:
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
+
+
+class ImagePatcher:
+
+ def determine_window_size(self, long: int, short: int) -> int:
+ if long <= 728:
+ return short if long / short > 1.5 else 0
+ return min(short, 504) if long / short > 4 else 504
+
+ def slide_window(
+ self,
+ width: int,
+ height: int,
+ sizes: list[tuple[int, int]],
+ steps: list[tuple[int, int]],
+ img_rate_thr: float = 0.6,
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
+ assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
+ windows = []
+ # Sliding windows.
+ for size, step in zip(sizes, steps):
+ size_w, size_h = size
+ step_w, step_h = step
+
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
+ 1)
+ x_start = [step_w * i for i in range(x_num)]
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
+ x_start[-1] = width - size_w
+
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
+ step_h + 1)
+ y_start = [step_h * i for i in range(y_num)]
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
+ y_start[-1] = height - size_h
+
+ start = np.array(list(product(y_start, x_start)), dtype=int)
+ start[:, [0, 1]] = start[:, [1, 0]]
+ windows.append(np.concatenate([start, start + size], axis=1))
+ windows = np.concatenate(windows, axis=0)
+
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
+
+ def square_pad(self, img: Image.Image) -> Image.Image:
+ w, h = img.size
+ if w == h:
+ return img
+ size = max(w, h)
+ padded = Image.new(img.mode, (size, size), 0)
+ padded.paste(img, (0, 0))
+ return padded
+
+ def get_image_size_for_padding(self, img_width: int,
+ img_height: int) -> tuple[int, int]:
+ ratio = img_width / img_height
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
+ new_size = max(img_height, img_width)
+ return new_size, new_size
+ return img_width, img_height
+
+ def get_image_size_for_preprocess(self, img_width: int,
+ img_height: int) -> tuple[int, int]:
+
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
+ img_width = int(img_width * scale_factor)
+ img_height = int(img_height * scale_factor)
+ return img_width, img_height
+
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
+ window_size: int):
+ w_ratio = img_width / window_size
+ h_ratio = img_height / window_size
+
+ if w_ratio < 1:
+ width_new = img_width
+ else:
+ decimal_w = w_ratio - img_width // window_size
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
+ width_new = window_size * w_ratio
+ if h_ratio < 1:
+ height_new = img_height
+ else:
+ decimal_h = h_ratio - img_height // window_size
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
+ height_new = window_size * h_ratio
+ return int(width_new), int(height_new)
+
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
+ target = img.crop((j, i, j + tw, i + th))
+ return target
+
+ def get_num_patches(self, img_width: int,
+ img_height: int) -> tuple[int, int]:
+ img_width, img_height = self.get_image_size_for_padding(
+ img_width, img_height)
+ img_width, img_height = self.get_image_size_for_preprocess(
+ img_width, img_height)
+ window_size = self.determine_window_size(max(img_height, img_width),
+ min(img_height, img_width))
+ if window_size == 0:
+ return 0, 0
+ else:
+ img_width, img_height = self.get_image_size_for_crop(
+ img_width, img_height, window_size)
+ center_list, (x_num, y_num) = self.slide_window(
+ img_width, img_height, [(window_size, window_size)],
+ [(window_size, window_size)])
+ full_rows = (len(center_list) - 1) // x_num + 1
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
+ full_rows -= 1
+ return len(center_list), full_rows
+
+ def __call__(
+ self, img: Image.Image
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
+ img_width, img_height = img.size
+ new_img_width, new_img_height = self.get_image_size_for_padding(
+ img_width, img_height)
+ if new_img_width != img_width or new_img_height != img_height:
+ img = self.square_pad(img)
+ img_width, img_height = img.size
+
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
+ img_width, img_height)
+ img = img.resize((new_img_width, new_img_height),
+ Image.Resampling.BILINEAR)
+ window_size = self.determine_window_size(
+ max(new_img_height, new_img_width),
+ min(new_img_height, new_img_width))
+
+ if window_size == 0:
+ return img, [], None
+ else:
+ new_img_width, new_img_height = self.get_image_size_for_crop(
+ new_img_width, new_img_height, window_size)
+ if (new_img_width, new_img_height) != (img_width, img_height):
+ img_for_crop = img.resize((new_img_width, new_img_height),
+ Image.Resampling.BILINEAR)
+ else:
+ img_for_crop = img
+
+ patches = []
+ newlines = []
+ center_list, (x_num, y_num) = self.slide_window(
+ new_img_width, new_img_height, [(window_size, window_size)],
+ [(window_size, window_size)])
+ for patch_id, center_lf_point in enumerate(center_list):
+ x, y, patch_w, patch_h = center_lf_point
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
+ patch_w)
+ patches.append(big_patch)
+ if (patch_id + 1) % x_num == 0:
+ newlines.append(patch_id)
+
+ if newlines and newlines[-1] == len(patches) - 1:
+ newlines.pop()
+
+ return img, patches, [i in newlines for i in range(len(patches))
+ ] if len(patches) > 0 else None
+
+
+class Step3VLProcessor:
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.image_size = 728
+ self.patch_size = 504
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
+ "bilinear",
+ self.patch_size)
+
+ self.num_image_feature_size = 169
+ self.num_patch_feature_size = 81
+ self.image_token = "<im_patch>"
+ self.image_feature_placeholder = (self.image_token *
+ self.num_image_feature_size)
+ self.patch_feature_placeholder = (self.image_token *
+ self.num_patch_feature_size)
+
+ self.patcher = ImagePatcher()
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[self.image_token]
+
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
+ num_patches, num_newlines = self.patcher.get_num_patches(
+ img_width, img_height)
+
+ return num_patches * (
+ self.num_patch_feature_size +
+ 2) + self.num_image_feature_size + 2 + num_newlines
+
+ def _split_images(self,
+ images: list[Image.Image]) -> list[ImageWithPatches]:
+ result = []
+ for img in images:
+ result.append(self.patcher(img))
+ return result
+
+ def _convert_images_to_pixel_values(
+ self,
+ images: list[Image.Image],
+ is_patch: bool = False,
+ ) -> list[torch.Tensor]:
+ return [
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
+ for img in images
+ ]
+
+ def _get_patch_repl(
+ self,
+ num_patches: int,
+ patch_newline_mask: list[bool] | None,
+ ) -> tuple[str, list[int]]:
+ text = ""
+ token_ids = []
+ for i in range(num_patches):
+ assert len(patch_newline_mask) == num_patches
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
+ token_ids.extend(
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
+ [self.image_token_id] * self.num_patch_feature_size +
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
+ if patch_newline_mask and patch_newline_mask[i]:
+ text += "<patch_newline>"
+ token_ids.append(
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
+ return text, token_ids
+
+ def _get_image_repl(
+ self,
+ num_images: int,
+ ) -> tuple[str, list[int]]:
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
+ token_ids = [
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
+ ] + [self.image_token_id] * self.num_image_feature_size + [
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
+ ]
+ return text * num_images, token_ids * num_images
+
+ def _get_image_repl_features(
+ self,
+ num_images: int,
+ num_patches: int,
+ patch_new_line_idx: Optional[list[bool]],
+ ) -> tuple[str, list[int]]:
+ if num_patches > 0:
+ patch_repl, patch_repl_ids = self._get_patch_repl(
+ num_patches, patch_new_line_idx)
+ else:
+ patch_repl = ""
+ patch_repl_ids = []
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
+
+ def replace_placeholder(self, text: str, placeholder: str,
+ repls: list[str]) -> str:
+ parts = text.split(placeholder)
+
+ if len(parts) - 1 != len(repls):
+ raise ValueError(
+ "The number of placeholders does not match the number of replacements." # noqa: E501
+ )
+
+ result = [parts[0]]
+ for i, repl in enumerate(repls):
+ result.append(repl)
+ result.append(parts[i + 1])
+
+ return "".join(result)
+
+ def __call__(
+ self,
+ text: Optional[Union[str, list[str]]] = None,
+ images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ ) -> BatchFeature:
+ if text is None:
+ text = []
+ if not isinstance(text, list):
+ text = [text]
+ if images is None:
+ images = []
+ if not isinstance(images, list):
+ images = [images]
+
+ if len(images) == 0:
+ image_inputs = {}
+ text_inputs = self.tokenizer(text)
+ else:
+ splitted_images_data = self._split_images(images)
+ pixel_values_lst = []
+ patch_pixel_values_lst = []
+ patch_newline_mask_lst = []
+ image_repl_str_lst = []
+ image_repl_ids_lst = []
+ num_patches = []
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
+ pixel_values_lst.extend(
+ self._convert_images_to_pixel_values([raw_img]))
+
+ if len(img_patches) > 0:
+ patch_pixel_values_lst.extend(
+ self._convert_images_to_pixel_values(img_patches,
+ is_patch=True))
+ num_patches.append(len(img_patches))
+
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
+ 1, len(img_patches), patch_newline_mask)
+ image_repl_str_lst.append(image_repl_str)
+ image_repl_ids_lst.extend(image_repl_ids)
+
+ if patch_newline_mask is not None:
+ patch_newline_mask_lst.extend(patch_newline_mask)
+
+ image_inputs = {
+ "pixel_values": torch.cat(pixel_values_lst),
+ "num_patches": num_patches,
+ }
+ if patch_pixel_values_lst:
+ image_inputs["patch_pixel_values"] = torch.cat(
+ patch_pixel_values_lst)
+ if patch_newline_mask_lst:
+ image_inputs["patch_newline_mask"] = torch.tensor(
+ patch_newline_mask_lst, dtype=torch.bool)
+
+ text = [
+ self.replace_placeholder(t, self.image_token,
+ image_repl_str_lst) for t in text
+ ]
+ text_inputs = self.tokenizer(text)
+
+ return BatchFeature(
+ {
+ **text_inputs,
+ **image_inputs,
+ },
+ tensor_type=return_tensors,
+ )
+
+
+class Step3VLProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_processor(self) -> Step3VLProcessor:
+ return Step3VLProcessor(
+ self.get_hf_config(),
+ self.get_tokenizer(),
+ )
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_max_image_tokens(self) -> int:
+ hf_processor = self.get_hf_processor()
+ return hf_processor.get_num_image_tokens(
+ self.get_image_size_with_most_features().width,
+ self.get_image_size_with_most_features().height)
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_max_image_tokens()}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ return ImageSize(3024, 3024)
+
+ def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
+ if len(mm_data) != 1 or "image" not in mm_data:
+ raise ValueError(
+ "mm_data could only contain one key 'image' for step1o")
+
+ image_data = mm_data["image"]
+ if not isinstance(image_data, (list, tuple)):
+ image_data = [image_data]
+
+ return sum(self.get_hf_processor().get_num_image_tokens(
+ img.width, img.height) for img in image_data)
+
+
+class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ return "<im_patch>" * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ num_images = mm_counts.get("image", 0)
+
+ return {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+
+class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]
+ ):
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_placeholder_token_id = hf_processor.image_token_id
+ batch_num_patches = out_mm_kwargs["num_patches"].tolist()
+
+ def get_replacement_step1o(item_idx: int):
+ img_out = out_mm_kwargs.get_item("image", item_idx)
+ num_patches = batch_num_patches[item_idx]
+ if num_patches > 0:
+ patch_newline_mask = img_out["patch_newline_mask"].data.tolist(
+ )
+ image_repl_ids = hf_processor._get_image_repl_features(
+ 1, num_patches, patch_newline_mask)[1]
+ else:
+ image_repl_ids = hf_processor._get_image_repl_features(
+ 1, 0, None)[1]
+ return PromptUpdateDetails.select_token_id(
+ seq=image_repl_ids,
+ embed_token_id=image_placeholder_token_id,
+ )
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_placeholder_token_id],
+ replacement=get_replacement_step1o,
+ )
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
+ return dict(
+ pixel_values=MultiModalFieldConfig.batched("image"),
+ patch_pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches),
+ num_patches=MultiModalFieldConfig.batched("image"),
+ patch_newline_mask=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches),
+ )
+
+
+def get_abs_pos(abs_pos, tgt_size):
+ dim = abs_pos.size(-1)
+ abs_pos_new = abs_pos.squeeze(0)
+ cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
+
+ src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
+ tgt_size = int(math.sqrt(tgt_size))
+ dtype = abs_pos.dtype
+
+ if src_size != tgt_size:
+ old_pos_embed = old_pos_embed.view(1, src_size, src_size,
+ dim).permute(0, 3, 1,
+ 2).contiguous()
+ old_pos_embed = old_pos_embed.to(torch.float32)
+ new_pos_embed = F.interpolate(
+ old_pos_embed,
+ size=(tgt_size, tgt_size),
+ mode='bicubic',
+ antialias=True,
+ align_corners=False,
+ ).to(dtype)
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
+ new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
+ vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
+ vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1,
+ dim)
+ return vision_pos_embed
+ else:
+ return abs_pos
+
+
+class Step3VisionEmbeddings(nn.Module):
+
+ def __init__(self, config: Step3VisionEncoderConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=True,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.pad_tp_size = 4 # hard code for padding
+ # To load the pretrained weights, we still use P+1 as the seqlen
+ self.position_embedding = torch.nn.Embedding(self.num_patches + 1,
+ self.embed_dim)
+ self.register_buffer("position_ids",
+ torch.arange(self.num_patches + 1).expand(
+ (1, -1)),
+ persistent=False)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(
+ pixel_values) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ # pad
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + get_abs_pos(
+ self.position_embedding(self.position_ids), patch_embeds.size(1))
+ embeddings = torch.cat([
+ embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1,
+ 1), embeddings
+ ],
+ dim=1)
+ return embeddings
+
+
+class Step3VisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.total_num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.total_num_heads
+
+ self.scale = self.head_dim**-0.5
+
+ tp_size = get_tensor_model_parallel_world_size()
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.qkv_proj = QKVParallelLinear(self.embed_dim,
+ self.head_dim,
+ self.total_num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+ self.out_proj = RowParallelLinear(self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads,
+ self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ """Input shape: Batch x Time x Channel"""
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.chunk(chunks=3, dim=-1)
+ q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ attn_output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ scale=self.scale,
+ is_causal=False)
+ attn_output = attn_output.transpose(1, 2).reshape(
+ bsz, tgt_len, self.num_heads * self.head_dim)
+
+ attn_output, _ = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Step3VisionMLP(nn.Module):
+
+ def __init__(self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.activation_fn = get_act_fn(config.hidden_act)
+ self.fc1 = ColumnParallelLinear(config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+ self.fc2 = RowParallelLinear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Step3VisionEncoderLayer(nn.Module):
+
+ def __init__(self,
+ config: Step3VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Step3VisionAttention(config,
+ quant_config,
+ prefix=f"{prefix}.self_attn")
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+ self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.FloatTensor:
+ hidden_states = hidden_states + self.layer_norm1(
+ self.self_attn(hidden_states))
+ hidden_states = hidden_states + self.layer_norm2(
+ self.mlp(hidden_states))
+ return hidden_states
+
+
+class Step3VisionEncoder(nn.Module):
+
+ def __init__(self,
+ config: Step3VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ Step3VisionEncoderLayer(config,
+ quant_config,
+ prefix=f"{prefix}.layers.{i}")
+ for i in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ inputs_embeds,
+ ):
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(hidden_states)
+ return hidden_states
+
+
+class Step3VisionTransformer(nn.Module):
+
+ def __init__(self,
+ config: Step3VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.image_size = config.image_size
+ self.embeddings = Step3VisionEmbeddings(config)
+ self.transformer = Step3VisionEncoder(config,
+ quant_config,
+ prefix=f"{prefix}.transformer")
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ ):
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.transformer(inputs_embeds=hidden_states)
+ return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor,
+ info=Step3VLProcessingInfo,
+ dummy_inputs=Step3VLDummyInputsBuilder)
+class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsPP):
+
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+ "model.": "language_model.model.",
+ "lm_head.": "language_model.lm_head.",
+ })
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return "<im_patch>"
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ self.vision_model = Step3VisionTransformer(config.vision_config,
+ None,
+ prefix=maybe_prefix(
+ prefix, "vision_model"))
+ self.vit_downsampler = nn.Conv2d(
+ config.vision_config.hidden_size,
+ config.vision_config.output_hidden_size,
+ kernel_size=2,
+ stride=config.understand_projector_stride)
+ self.vit_downsampler2 = nn.Conv2d(
+ config.vision_config.output_hidden_size,
+ config.vision_config.output_hidden_size * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ self.vit_large_projector = nn.Linear(
+ config.vision_config.output_hidden_size * 2,
+ config.hidden_size,
+ bias=config.projector_bias,
+ )
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"))
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[Step3VLImageInputs]:
+        """Pull image tensors out of ``kwargs`` and normalize their shapes.
+
+        Returns pixel inputs when ``pixel_values`` is present, embedding
+        inputs when only ``image_embeds`` is present, or None when the
+        request carries no image data at all.
+        """
+        pixel_values = kwargs.pop("pixel_values", None)
+        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
+        num_patches = kwargs.pop("num_patches", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        if pixel_values is not None:
+            # Collapse batch/list nesting, then force a flat (N, C, H, W)
+            # layout regardless of how the inputs were batched.
+            pixel_values = flatten_bn(pixel_values, concat=True)
+            if pixel_values.dim() >= 3:
+                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
+            if patch_pixel_values is not None:
+                patch_pixel_values = flatten_bn(patch_pixel_values,
+                                                concat=True)
+                patch_pixel_values = patch_pixel_values.view(
+                    -1, *patch_pixel_values.shape[-3:])
+                # Handle empty patch_pixel_values by setting to None
+                if patch_pixel_values.shape[0] == 0:
+                    patch_pixel_values = None
+            # NOTE(review): assumes num_patches always accompanies
+            # pixel_values (a None here would raise) — confirm in processor.
+            num_patches = flatten_bn(num_patches, concat=True).tolist()
+
+            return Step3VLImagePixelInputs(
+                type="pixel_values",
+                pixel_values=pixel_values.to(self.dtype).to(self.device),
+                patch_pixel_values=patch_pixel_values.to(self.dtype).to(
+                    self.device) if patch_pixel_values is not None else None,
+                num_patches=num_patches,
+            )
+
+        if image_embeds is not None:
+            # Flatten any leading batch dims to (num_tokens, hidden);
+            # 0-d / 1-d inputs are rejected.
+            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
+                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
+            else:
+                raise ValueError(
+                    f"Unexpected shape for image_embeds: {image_embeds.shape}")
+
+            return Step3VLImageEmbeddingInputs(
+                type="image_embeds",
+                image_embeds=image_embeds.to(self.dtype).to(self.device),
+            )
+        return None
+
+    def _process_image_features(self,
+                                image_features: torch.Tensor) -> torch.Tensor:
+        """Downsample (B, P, C) ViT features and project to the LM width.
+
+        The P patch tokens are reshaped into a square HW x HW grid, so this
+        assumes P is a perfect square — TODO confirm for all image sizes.
+        """
+        B, P = image_features.shape[:2]
+        HW = int(sqrt(P))
+        # (B, P, C) -> (B, C, HW, HW) for the conv downsamplers.
+        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
+        image_features = self.vit_downsampler(image_features)
+        image_features = self.vit_downsampler2(image_features)
+        n_dim = image_features.size(1)
+        # Back to (B, P', C') token layout before the linear projection.
+        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
+        image_features = self.vit_large_projector(image_features)
+        return image_features
+
+    def _get_vision_model_output(self,
+                                 input_tensor: torch.Tensor) -> torch.Tensor:
+        # Drops the first 4 output tokens — presumably class/register tokens
+        # prepended by Step3VisionTransformer; confirm against the encoder.
+        return self.vision_model(input_tensor)[:, 4:]
+
+    def _process_image_input(
+        self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]:
+        """Encode + project image inputs, merging per-image patch features.
+
+        For each image i, the output is its num_patches[i] patch features
+        (if any) concatenated in front of its whole-image features.
+
+        NOTE(review): in the "image_embeds" branch, patch_image_features and
+        num_patches are never assigned, so the code below would raise
+        UnboundLocalError; that path looks unexercised — confirm upstream.
+        The annotation says tuple but a list is returned.
+        """
+        if image_input["type"] == "image_embeds":
+            image_features = image_input["image_embeds"]
+        else:
+            image_features = self._get_vision_model_output(
+                image_input["pixel_values"])
+            patch_image_features = self._get_vision_model_output(
+                image_input["patch_pixel_values"]
+            ) if image_input["patch_pixel_values"] is not None else None
+            num_patches = image_input["num_patches"]
+
+        image_features = self._process_image_features(image_features)
+        patch_image_features = self._process_image_features(
+            patch_image_features) if patch_image_features is not None else None
+
+        # Walk the flat patch-feature tensor, carving out num_patch rows for
+        # each image and prepending them to that image's own features.
+        merged_image_features = []
+        cur_patch_idx = 0
+        for i, num_patch in enumerate(num_patches):
+            cur_feature = []
+            if num_patch > 0:
+                patch_slice = patch_image_features[
+                    cur_patch_idx:cur_patch_idx + num_patch]
+                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
+            cur_feature.append(image_features[i].view(
+                -1, image_features.shape[-1]))
+            cur_patch_idx += num_patch
+            merged_image_features.append(
+                torch.cat(cur_feature) if len(cur_feature) >
+                1 else cur_feature[0])
+        return merged_image_features
+
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        """Return per-image embedding tensors, or None if no images given."""
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        """Embed input_ids, splicing image embeddings at image-token slots.
+
+        Assumes input_ids is a flat 1-D token tensor — TODO confirm callers.
+        """
+        if multimodal_embeddings is None:
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
+        else:
+            # Embed only the text positions; image positions stay
+            # uninitialized (torch.empty) until merge_multimodal_embeddings
+            # overwrites them below.
+            is_text = input_ids != self.config.image_token_id
+            text_ids = input_ids[is_text]
+            text_embeds = self.language_model.model.get_input_embeddings(
+                text_ids)
+            inputs_embeds = torch.empty(input_ids.shape[0],
+                                        text_embeds.shape[-1],
+                                        dtype=text_embeds.dtype,
+                                        device=text_embeds.device)
+            inputs_embeds[is_text] = text_embeds
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.image_token_id)
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        """Run the multimodal model; extra kwargs carry the image tensors.
+
+        On non-first pipeline stages (intermediate_tensors set) embeddings
+        are not used; otherwise they are computed here so the LM always
+        receives inputs_embeds rather than raw ids.
+        """
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
+
+        hidden_states = self.language_model(input_ids,
+                                            positions,
+                                            intermediate_tensors,
+                                            inputs_embeds=inputs_embeds)
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        """Delegate logit computation to the language model."""
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        """Delegate token sampling to the language model."""
+        return self.language_model.sample(logits, sampling_metadata)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        """Load checkpoint weights, remapping HF names to vLLM names.
+
+        Relies on a class-level `hf_to_vllm_mapper` defined elsewhere in
+        this class — TODO confirm it covers all checkpoint prefixes.
+        """
+        loader = AutoWeightsLoader(self)
+        loaded_weights = loader.load_weights(weights,
+                                             mapper=self.hf_to_vllm_mapper)
+        return loaded_weights
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index d61e4f11dfa29..1c3f78f2edbfb 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
+from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [
"ReasoningParser",
@@ -18,4 +19,5 @@ __all__ = [
"Qwen3ReasoningParser",
"Glm4MoeModelReasoningParser",
"MistralReasoningParser",
+ "Step3ReasoningParser",
]
diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py
new file mode 100644
index 0000000000000..f642ea977c580
--- /dev/null
+++ b/vllm/reasoning/step3_reasoning_parser.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+import regex as re
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("step3")
+class Step3ReasoningParser(ReasoningParser):
+ """
+ Reasoning parser for Step3 model.
+
+ The Step3 model uses token to denote the end of reasoning
+ text. This parser extracts all content before as reasoning content.
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
+ super().__init__(tokenizer)
+ self.think_end_token = ""
+
+ self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
+ re.DOTALL)
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ReasoningParser "
+ "constructor during construction.")
+
+ self.think_end_token_id = self.vocab.get(self.think_end_token)
+ if self.think_end_token_id is None:
+ raise RuntimeError(
+ "Step3 reasoning parser could not locate think end "
+ "token in the tokenizer!")
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> Union[DeltaMessage, None]:
+ """
+ Extract reasoning content from a delta message.
+ Handles streaming output where previous + delta = current.
+ Uses token IDs for faster processing.
+ For text "abcxyz":
+ - 'abc' goes to reasoning_content
+ - 'xyz' goes to content
+ """
+ # Skip single special token
+ if len(delta_token_ids
+ ) == 1 and delta_token_ids[0] == self.think_end_token_id:
+ return None
+
+ if self.think_end_token_id in delta_token_ids:
+ # in delta, extract reasoning content and remaining content
+ end_index = delta_text.find(self.think_end_token)
+ reasoning_content = delta_text[:end_index]
+ content = delta_text[end_index + len(self.think_end_token):]
+ return DeltaMessage(reasoning_content=reasoning_content,
+ content=content if content else None)
+ elif self.think_end_token_id in previous_token_ids:
+ # already seen in previous text, everything is content
+ return DeltaMessage(content=delta_text)
+ else:
+ # No seen yet, everything is reasoning
+ return DeltaMessage(reasoning_content=delta_text)
+
+ def extract_reasoning_content(
+ self, model_output: str, request: ChatCompletionRequest
+ ) -> tuple[Optional[str], Optional[str]]:
+
+ # Check if the model output contains the token
+ if self.think_end_token not in model_output:
+ # If no token, everything is reasoning content
+ return model_output, None
+ else:
+ # Find the first occurrence of
+ end_index = model_output.find(self.think_end_token)
+ reasoning_content = model_output[:end_index]
+
+ # Content after token
+ content = model_output[end_index + len(self.think_end_token):]
+
+ if len(content) == 0:
+ content = None
+
+ return reasoning_content, content
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ return self.think_end_token_id in input_ids
+
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ if self.think_end_token_id not in input_ids[:-1]:
+ return []
+ else:
+ return input_ids[input_ids.index(self.think_end_token_id) + 1:]
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 4ce56cb3a6aac..fcaa48c1392a3 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
- RWConfig, UltravoxConfig)
+ RWConfig, Step3TextConfig,
+ Step3VLConfig, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig,
+ "step3_vl": Step3VLConfig,
+ "step3_text": Step3TextConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 7c7d859e4a325..96733da726181 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
+ Step3VisionEncoderConfig,
+ Step3VLConfig)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
@@ -42,4 +45,7 @@ __all__ = [
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"UltravoxConfig",
+ "Step3VLConfig",
+ "Step3VisionEncoderConfig",
+ "Step3TextConfig",
]
diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py
new file mode 100644
index 0000000000000..fe3c72de69d28
--- /dev/null
+++ b/vllm/transformers_utils/configs/step3_vl.py
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Step3VisionEncoderConfig(PretrainedConfig):
+    """HF-style configuration for the Step3 vision encoder (ViT tower)."""
+
+    model_type = "step3_vision_encoder"
+
+    def __init__(
+        self,
+        hidden_size=1792,
+        intermediate_size=3072,
+        output_hidden_size=4096,
+        num_hidden_layers=63,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=728,
+        patch_size=14,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        # Internal transformer width vs. the width exposed to the
+        # multimodal projector (output_hidden_size).
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.output_hidden_size = output_hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        super().__init__(**kwargs)
+
+
+class Step3TextConfig(PretrainedConfig):
+    """HF-style configuration for the Step3 text (MoE decoder) backbone.
+
+    Defaults mirror the released Step3 checkpoint — confirm against the
+    checkpoint's config.json before relying on them.
+    """
+
+    model_type = "step3_text"
+    architectures = ["Step3TextForCausalLM"]
+
+    def __init__(
+        self,
+        hidden_size: int = 7168,
+        intermediate_size: int = 18432,
+        num_attention_heads: int = 64,
+        num_attention_groups: int = 1,
+        num_hidden_layers: int = 61,
+        max_seq_len: int = 65536,
+        vocab_size: int = 128815,
+        rms_norm_eps: float = 1e-5,
+        moe_intermediate_size: int = 5120,
+        moe_num_experts: int = 48,
+        moe_top_k: int = 3,
+        rope_theta: float = 500000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embedding: int = 65536,
+        share_expert_dim: int = 5120,
+        share_q_dim: int = 2048,
+        head_dim: int = 256,
+        norm_expert_weight: bool = False,
+        # Indices of decoder layers that use MoE FFNs (layers 4..59 by
+        # default; the remaining layers are dense).
+        moe_layers_enum: tuple[int,
+                               ...] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+                                       15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
+                                       25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+                                       35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                                       45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+                                       55, 56, 57, 58, 59),
+        **kwargs,
+    ) -> None:
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.num_attention_groups = num_attention_groups
+        self.num_hidden_layers = num_hidden_layers
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.rms_norm_eps = rms_norm_eps
+        self.moe_intermediate_size = moe_intermediate_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        # NOTE: singular name ("embedding") kept as-is to match the
+        # checkpoint config key.
+        self.max_position_embedding = max_position_embedding
+        self.share_expert_dim = share_expert_dim
+        self.share_q_dim = share_q_dim
+        self.head_dim = head_dim
+        self.norm_expert_weight = norm_expert_weight
+        self.moe_layers_enum = moe_layers_enum
+
+        super().__init__(**kwargs)
+
+
+class Step3VLConfig(PretrainedConfig):
+    """Composite HF-style config for Step3-VL (vision + text + projector).
+
+    Accepts the sub-configs either as plain dicts (as deserialized from
+    config.json) or as ready config objects; missing sub-configs fall
+    back to their defaults.
+    """
+
+    model_type = "step3_vl"
+
+    def __init__(
+        self,
+        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
+        text_config: Optional[Union[dict, Step3TextConfig]] = None,
+        understand_projector_stride: int = 1,
+        projector_bias: bool = True,
+        image_token_id: int = 128001,
+        **kwargs,
+    ) -> None:
+        if vision_config is None:
+            vision_config = Step3VisionEncoderConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = Step3VisionEncoderConfig(**vision_config)
+        self.vision_config = vision_config
+
+        if text_config is None:
+            text_config = Step3TextConfig()
+        elif isinstance(text_config, dict):
+            text_config = Step3TextConfig(**text_config)
+        self.text_config = text_config
+
+        self.understand_projector_stride = understand_projector_stride
+        self.projector_bias = projector_bias
+        # Expose the text hidden size at the top level for code that reads
+        # config.hidden_size directly.
+        self.hidden_size = text_config.hidden_size
+        self.image_token_id = image_token_id
+
+        super().__init__(**kwargs)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 63f6fc276189d..302126dbe3d5f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
import numpy as np
import torch
import torch.nn as nn
@@ -51,6 +53,9 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
+ self.is_multimodal_model = vllm_config.model_config \
+ .is_multimodal_model
+
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
@@ -76,6 +81,11 @@ class EagleProposer:
device=device,
dtype=torch.int32)
+ self.inputs_embeds = torch.zeros(
+ (self.max_num_tokens, self.hidden_size),
+ dtype=self.dtype,
+ device=device)
+
def propose(
self,
# [num_tokens]
@@ -88,6 +98,7 @@ class EagleProposer:
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
+ mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@@ -128,14 +139,27 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
+ if self.is_multimodal_model:
+ input_ids = self.input_ids[:num_tokens]
+ inputs_embeds = self.model.get_input_embeddings(
+ input_ids,
+ multimodal_embeddings=mm_embeds or None,
+ )
+ self.inputs_embeds[:num_tokens] = inputs_embeds
+ inputs_embeds = self.inputs_embeds[:num_input_tokens]
+ input_ids = None
+ else:
+ inputs_embeds = None
+ input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
- self.input_ids[:num_input_tokens],
- self.positions[:num_input_tokens],
- self.hidden_states[:num_input_tokens],
+ input_ids=input_ids,
+ positions=self.positions[:num_input_tokens],
+ hidden_states=self.hidden_states[:num_input_tokens],
+ inputs_embeds=inputs_embeds,
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
@@ -218,15 +242,24 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
+ if self.is_multimodal_model:
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
+ self.inputs_embeds[:batch_size] = inputs_embeds
+ inputs_embeds = self.inputs_embeds[:input_batch_size]
+ input_ids = None
+ else:
+ inputs_embeds = None
+ input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
- self.input_ids[:input_batch_size],
- self.positions[:input_batch_size],
- self.hidden_states[:input_batch_size],
+ input_ids=input_ids,
+ positions=self.positions[:input_batch_size],
+ hidden_states=self.hidden_states[:input_batch_size],
+ inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -391,10 +424,18 @@ class EagleProposer:
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
+ if self.is_multimodal_model:
+ input_ids = None
+ inputs_embeds = self.inputs_embeds[:num_tokens]
+ else:
+ input_ids = self.input_ids[:num_tokens]
+ inputs_embeds = None
+
self.model(
- self.input_ids[:num_tokens],
- self.positions[:num_tokens],
- self.hidden_states[:num_tokens],
+ input_ids=input_ids,
+ positions=self.positions[:num_tokens],
+ hidden_states=self.hidden_states[:num_tokens],
+ inputs_embeds=inputs_embeds,
)
def validate_same_kv_cache_group(self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 28337a688e37f..360d24f39a1ae 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1314,13 +1314,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
+ shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
- num_computed_tokens = req_state.num_computed_tokens
+ num_computed_tokens = \
+ req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
@@ -2298,6 +2300,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
+ mm_embeds = None
+ if self.is_multimodal_model:
+ mm_embeds = self._gather_mm_embeddings(scheduler_output,
+ shift_computed_tokens=1)
+
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
@@ -2305,6 +2312,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
next_token_ids=next_token_ids,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
+ mm_embeds=mm_embeds,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids