mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 20:07:09 +08:00
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/dbo-full-cudagraphs
This commit is contained in:
commit
e283eff060
@ -82,7 +82,7 @@ steps:
|
||||
- bash standalone_tests/python_only_compile.sh
|
||||
|
||||
- label: Basic Correctness Test # 30min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
@ -99,7 +99,7 @@ steps:
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Chunked Prefill Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_chunked_prefill
|
||||
@ -108,7 +108,7 @@ steps:
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
|
||||
- label: Core Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/core
|
||||
@ -209,7 +209,7 @@ steps:
|
||||
- pytest -v -s distributed/test_eplb_execute.py
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -228,7 +228,7 @@ steps:
|
||||
##### 1 GPU test #####
|
||||
|
||||
- label: Regression Test # 5min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
@ -280,7 +280,7 @@ steps:
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
|
||||
- label: Examples Test # 25min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
source_file_dependencies:
|
||||
- vllm/entrypoints
|
||||
@ -305,7 +305,7 @@ steps:
|
||||
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
|
||||
- label: Prefix Caching Test # 9min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/prefix_caching
|
||||
@ -314,7 +314,7 @@ steps:
|
||||
|
||||
|
||||
- label: Platform Tests (CUDA)
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/cuda
|
||||
@ -353,9 +353,10 @@ steps:
|
||||
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
- pytest -v -s compile/test_async_tp.py
|
||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -368,7 +369,7 @@ steps:
|
||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -377,7 +378,7 @@ steps:
|
||||
- pytest -v -s compile/test_full_graph.py
|
||||
|
||||
- label: Kernels Core Operation Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- tests/kernels/core
|
||||
@ -416,7 +417,7 @@ steps:
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
@ -424,7 +425,7 @@ steps:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
@ -437,7 +438,7 @@ steps:
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Model Executor Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
@ -447,7 +448,7 @@ steps:
|
||||
- pytest -v -s model_executor
|
||||
|
||||
- label: Benchmarks # 9min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/.buildkite"
|
||||
source_file_dependencies:
|
||||
- benchmarks/
|
||||
@ -455,7 +456,7 @@ steps:
|
||||
- bash scripts/run-benchmarks.sh
|
||||
|
||||
- label: Benchmarks CLI Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/benchmarks/
|
||||
@ -494,7 +495,7 @@ steps:
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: Encoder Decoder tests # 5min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/encoder_decoder
|
||||
@ -502,7 +503,7 @@ steps:
|
||||
- pytest -v -s encoder_decoder
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 20 min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -623,7 +624,7 @@ steps:
|
||||
|
||||
# This test is used only in PR development phase to test individual models and should never run on main
|
||||
- label: Custom Models Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
commands:
|
||||
- echo 'Testing custom models...'
|
||||
@ -658,7 +659,7 @@ steps:
|
||||
##### multi gpus test #####
|
||||
|
||||
- label: Distributed Comm Ops Test # 7min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
@ -755,7 +756,7 @@ steps:
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
- label: Multi-step Tests (4 GPUs) # 36min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
@ -776,7 +777,7 @@ steps:
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
@ -790,7 +791,7 @@ steps:
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
|
||||
- label: LoRA TP Test (Distributed)
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
|
||||
@ -370,6 +370,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
fi
|
||||
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
# !bang
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system dist/*.whl --verbose \
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20250724"
|
||||
ARG NIGHTLY_DATE="20250730"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat
|
||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
||||
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
|
||||
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
|
||||
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
|
||||
|
||||
|
||||
@ -13,6 +13,38 @@ except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
|
||||
QUESTION = "What is the content of each image?"
|
||||
IMAGE_URLS = [
|
||||
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
|
||||
]
|
||||
|
||||
|
||||
def get_custom_mm_prompts(num_prompts):
|
||||
prompts = []
|
||||
for url in IMAGE_URLS:
|
||||
prompts.append(
|
||||
[
|
||||
{"type": "image_url", "image_url": {"url": url}},
|
||||
{"type": "text", "text": QUESTION},
|
||||
]
|
||||
)
|
||||
if num_prompts > len(IMAGE_URLS):
|
||||
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
|
||||
|
||||
return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
add_dataset_parser(parser)
|
||||
@ -35,6 +67,7 @@ def parse_args():
|
||||
parser.add_argument("--output-len", type=int, default=256)
|
||||
parser.add_argument("--model-dir", type=str, default=None)
|
||||
parser.add_argument("--eagle-dir", type=str, default=None)
|
||||
parser.add_argument("--custom-mm-prompts", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -44,14 +77,26 @@ def main():
|
||||
|
||||
model_dir = args.model_dir
|
||||
if args.model_dir is None:
|
||||
if args.custom_mm_prompts:
|
||||
raise ValueError(
|
||||
"custom_mm_prompts requires mm based models"
|
||||
"default llama3.1-8b-instruct is not mm based"
|
||||
"please specify model_dir to give a mm based model"
|
||||
)
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
args.custom_skip_chat_template = True
|
||||
|
||||
prompts = get_samples(args, tokenizer)
|
||||
# add_special_tokens is False to avoid adding bos twice when using chat templates
|
||||
prompt_ids = [
|
||||
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
|
||||
]
|
||||
if not args.custom_mm_prompts:
|
||||
prompts = get_samples(args, tokenizer)
|
||||
# add_special_tokens is False to avoid adding bos twice
|
||||
# when using chat templates
|
||||
prompt_ids = [
|
||||
tokenizer.encode(prompt.prompt, add_special_tokens=False)
|
||||
for prompt in prompts
|
||||
]
|
||||
else:
|
||||
prompts = get_custom_mm_prompts(args.num_prompts)
|
||||
|
||||
if args.method == "eagle" or args.method == "eagle3":
|
||||
eagle_dir = args.eagle_dir
|
||||
@ -85,10 +130,17 @@ def main():
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=16384,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
|
||||
if not args.custom_mm_prompts:
|
||||
outputs = llm.generate(
|
||||
prompt_token_ids=prompt_ids, sampling_params=sampling_params
|
||||
)
|
||||
else:
|
||||
outputs = llm.chat(prompts, sampling_params=sampling_params)
|
||||
|
||||
# print the generated text
|
||||
if args.print_output:
|
||||
|
||||
@ -22,9 +22,7 @@ aiohttp==3.10.11
|
||||
aiohttp-cors==0.8.1
|
||||
# via ray
|
||||
aiosignal==1.3.1
|
||||
# via
|
||||
# aiohttp
|
||||
# ray
|
||||
# via aiohttp
|
||||
albucore==0.0.16
|
||||
# via terratorch
|
||||
albumentations==1.4.6
|
||||
@ -139,7 +137,7 @@ contourpy==1.3.0
|
||||
# via matplotlib
|
||||
cramjam==2.9.0
|
||||
# via fastparquet
|
||||
cupy-cuda12x==13.3.0
|
||||
cupy-cuda12x==13.5.1
|
||||
# via ray
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
@ -226,7 +224,6 @@ frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
# ray
|
||||
fsspec==2024.9.0
|
||||
# via
|
||||
# datasets
|
||||
@ -603,10 +600,18 @@ opencv-python-headless==4.11.0.86
|
||||
opentelemetry-api==1.35.0
|
||||
# via
|
||||
# mlflow-skinny
|
||||
# opentelemetry-exporter-prometheus
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
opentelemetry-exporter-prometheus==0.56b0
|
||||
# via ray
|
||||
opentelemetry-proto==1.36.0
|
||||
# via ray
|
||||
opentelemetry-sdk==1.35.0
|
||||
# via mlflow-skinny
|
||||
# via
|
||||
# mlflow-skinny
|
||||
# opentelemetry-exporter-prometheus
|
||||
# ray
|
||||
opentelemetry-semantic-conventions==0.56b0
|
||||
# via opentelemetry-sdk
|
||||
packaging==24.2
|
||||
@ -697,7 +702,9 @@ pqdm==0.2.0
|
||||
pretrainedmodels==0.7.4
|
||||
# via segmentation-models-pytorch
|
||||
prometheus-client==0.22.0
|
||||
# via ray
|
||||
# via
|
||||
# opentelemetry-exporter-prometheus
|
||||
# ray
|
||||
propcache==0.2.0
|
||||
# via yarl
|
||||
proto-plus==1.26.1
|
||||
@ -707,6 +714,7 @@ protobuf==5.28.3
|
||||
# google-api-core
|
||||
# googleapis-common-protos
|
||||
# mlflow-skinny
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
# ray
|
||||
# tensorboardx
|
||||
@ -854,7 +862,7 @@ rasterio==1.4.3
|
||||
# rioxarray
|
||||
# terratorch
|
||||
# torchgeo
|
||||
ray==2.43.0
|
||||
ray==2.48.0
|
||||
# via -r requirements/test.in
|
||||
redis==5.2.0
|
||||
# via tensorizer
|
||||
|
||||
@ -19,8 +19,8 @@ nixl==0.3.0
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.9.0.dev20250724
|
||||
torchvision==0.24.0.dev20250724
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
|
||||
torch==2.9.0.dev20250730
|
||||
torchvision==0.24.0.dev20250730
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
|
||||
|
||||
|
||||
203
setup.py
203
setup.py
@ -282,10 +282,69 @@ class cmake_build_ext(build_ext):
|
||||
self.copy_file(file, dst_file)
|
||||
|
||||
|
||||
class repackage_wheel(build_ext):
|
||||
class precompiled_wheel_utils:
|
||||
"""Extracts libraries and other files from an existing wheel."""
|
||||
|
||||
def get_base_commit_in_main_branch(self) -> str:
|
||||
@staticmethod
|
||||
def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
temp_dir = None
|
||||
try:
|
||||
if not os.path.isfile(wheel_url_or_path):
|
||||
wheel_filename = wheel_url_or_path.split("/")[-1]
|
||||
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
||||
wheel_path = os.path.join(temp_dir, wheel_filename)
|
||||
print(f"Downloading wheel from {wheel_url_or_path} "
|
||||
f"to {wheel_path}")
|
||||
from urllib.request import urlretrieve
|
||||
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
||||
else:
|
||||
wheel_path = wheel_url_or_path
|
||||
print(f"Using existing wheel at {wheel_path}")
|
||||
|
||||
package_data_patch = {}
|
||||
|
||||
with zipfile.ZipFile(wheel_path) as wheel:
|
||||
files_to_copy = [
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
]
|
||||
|
||||
compiled_regex = re.compile(
|
||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
||||
file_members = list(
|
||||
filter(lambda x: x.filename in files_to_copy,
|
||||
wheel.filelist))
|
||||
file_members += list(
|
||||
filter(lambda x: compiled_regex.match(x.filename),
|
||||
wheel.filelist))
|
||||
|
||||
for file in file_members:
|
||||
print(f"[extract] {file.filename}")
|
||||
target_path = os.path.join(".", file.filename)
|
||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||
with wheel.open(file.filename) as src, open(
|
||||
target_path, "wb") as dst:
|
||||
shutil.copyfileobj(src, dst)
|
||||
|
||||
pkg = os.path.dirname(file.filename).replace("/", ".")
|
||||
package_data_patch.setdefault(pkg, []).append(
|
||||
os.path.basename(file.filename))
|
||||
|
||||
return package_data_patch
|
||||
finally:
|
||||
if temp_dir is not None:
|
||||
print(f"Removing temporary directory {temp_dir}")
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@staticmethod
|
||||
def get_base_commit_in_main_branch() -> str:
|
||||
# Force to use the nightly wheel. This is mainly used for CI testing.
|
||||
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
|
||||
return "nightly"
|
||||
@ -334,115 +393,6 @@ class repackage_wheel(build_ext):
|
||||
"wheel may not be compatible with your dev branch: %s", err)
|
||||
return "nightly"
|
||||
|
||||
def run(self) -> None:
|
||||
assert _is_cuda(
|
||||
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||
|
||||
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
|
||||
if wheel_location is None:
|
||||
base_commit = self.get_base_commit_in_main_branch()
|
||||
wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
# Fallback to nightly wheel if latest commit wheel is unavailable,
|
||||
# in this rare case, the nightly release CI hasn't finished on main.
|
||||
if not is_url_available(wheel_location):
|
||||
wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
|
||||
import zipfile
|
||||
|
||||
if os.path.isfile(wheel_location):
|
||||
wheel_path = wheel_location
|
||||
print(f"Using existing wheel={wheel_path}")
|
||||
else:
|
||||
# Download the wheel from a given URL, assume
|
||||
# the filename is the last part of the URL
|
||||
wheel_filename = wheel_location.split("/")[-1]
|
||||
|
||||
import tempfile
|
||||
|
||||
# create a temporary directory to store the wheel
|
||||
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
||||
wheel_path = os.path.join(temp_dir, wheel_filename)
|
||||
print(f"Downloading wheel from {wheel_location} to {wheel_path}")
|
||||
from urllib.request import urlretrieve
|
||||
try:
|
||||
urlretrieve(wheel_location, filename=wheel_path)
|
||||
except Exception as e:
|
||||
from setuptools.errors import SetupError
|
||||
raise SetupError(
|
||||
f"Failed to get vLLM wheel from {wheel_location}") from e
|
||||
|
||||
# Set the dist_dir for Docker build context
|
||||
dist_dir = ("/workspace/dist"
|
||||
if envs.VLLM_DOCKER_BUILD_CONTEXT else "dist")
|
||||
os.makedirs(dist_dir, exist_ok=True)
|
||||
|
||||
# Extract only necessary compiled .so files from precompiled wheel
|
||||
with zipfile.ZipFile(wheel_path) as wheel:
|
||||
# Get version from METADATA (optional, mostly useful for logging)
|
||||
metadata_file = next((n for n in wheel.namelist()
|
||||
if n.endswith(".dist-info/METADATA")), None)
|
||||
if not metadata_file:
|
||||
raise RuntimeError(
|
||||
"Could not find METADATA in precompiled wheel.")
|
||||
metadata = wheel.read(metadata_file).decode()
|
||||
version_line = next((line for line in metadata.splitlines()
|
||||
if line.startswith("Version: ")), None)
|
||||
if not version_line:
|
||||
raise RuntimeError(
|
||||
"Could not determine version from METADATA.")
|
||||
version = version_line.split(": ")[1].strip()
|
||||
|
||||
print(f"Extracting precompiled kernels from vLLM wheel version: "
|
||||
f"{version}")
|
||||
|
||||
# List of compiled shared objects to extract
|
||||
files_to_copy = [
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
]
|
||||
|
||||
file_members = list(
|
||||
filter(lambda x: x.filename in files_to_copy, wheel.filelist))
|
||||
compiled_regex = re.compile(
|
||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
||||
file_members += list(
|
||||
filter(lambda x: compiled_regex.match(x.filename),
|
||||
wheel.filelist))
|
||||
|
||||
for file in file_members:
|
||||
print(f"Extracting and including {file.filename} "
|
||||
"from existing wheel")
|
||||
package_name = os.path.dirname(file.filename).replace("/", ".")
|
||||
file_name = os.path.basename(file.filename)
|
||||
|
||||
if package_name not in package_data:
|
||||
package_data[package_name] = []
|
||||
|
||||
output_base = (dist_dir
|
||||
if envs.VLLM_DOCKER_BUILD_CONTEXT else ".")
|
||||
target_path = os.path.join(output_base, file.filename)
|
||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||
with wheel.open(file.filename) as src, open(target_path,
|
||||
"wb") as dst:
|
||||
shutil.copyfileobj(src, dst)
|
||||
|
||||
package_data[package_name].append(file_name)
|
||||
|
||||
# Copy wheel into dist dir for Docker to consume (e.g., via --mount)
|
||||
if envs.VLLM_DOCKER_BUILD_CONTEXT:
|
||||
arch_tag = "cp38-abi3-manylinux1_x86_64"
|
||||
corrected_wheel_name = f"vllm-{version}-{arch_tag}.whl"
|
||||
final_wheel_path = os.path.join(dist_dir, corrected_wheel_name)
|
||||
|
||||
print(
|
||||
"Docker build context detected, copying precompiled wheel to "
|
||||
f"{final_wheel_path}")
|
||||
shutil.copy2(wheel_path, final_wheel_path)
|
||||
|
||||
|
||||
def _no_device() -> bool:
|
||||
return VLLM_TARGET_DEVICE == "empty"
|
||||
@ -676,16 +626,37 @@ package_data = {
|
||||
]
|
||||
}
|
||||
|
||||
# If using precompiled, extract and patch package_data (in advance of setup)
|
||||
if envs.VLLM_USE_PRECOMPILED:
|
||||
assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
|
||||
if wheel_location is not None:
|
||||
wheel_url = wheel_location
|
||||
else:
|
||||
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
from urllib.request import urlopen
|
||||
try:
|
||||
with urlopen(wheel_url) as resp:
|
||||
if resp.status != 200:
|
||||
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
except Exception as e:
|
||||
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
|
||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
||||
wheel_url)
|
||||
for pkg, files in patch.items():
|
||||
package_data.setdefault(pkg, []).extend(files)
|
||||
|
||||
if _no_device():
|
||||
ext_modules = []
|
||||
|
||||
if not ext_modules:
|
||||
if not ext_modules or envs.VLLM_USE_PRECOMPILED:
|
||||
# Disable build_ext when using precompiled wheel
|
||||
cmdclass = {}
|
||||
else:
|
||||
cmdclass = {
|
||||
"build_ext":
|
||||
repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
|
||||
}
|
||||
cmdclass = {"build_ext": cmake_build_ext}
|
||||
|
||||
setup(
|
||||
# static metadata should rather go in pyproject.toml
|
||||
|
||||
@ -7,22 +7,26 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
||||
ModelConfig, PassConfig, VllmConfig)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
GroupShape, QuantFP8)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
from ..utils import has_module_attribute, multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.quant_fp8 = QuantFP8(static=True,
|
||||
group_shape=GroupShape.PER_TENSOR)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size),
|
||||
dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
torch.ops._C.static_scaled_fp8_quant(self.output,
|
||||
norm_output.contiguous(),
|
||||
self.scale)
|
||||
return self.output, residual_output
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size),
|
||||
dtype=torch.float32)
|
||||
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(token_num, 128)
|
||||
scale_n = hidden_size // 16
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
|
||||
dtype=torch.int32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
|
||||
self.output_scale, self.scale)
|
||||
return self.output, residual_output, self.output_scale
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.scaled_fp4_quant.default
|
||||
]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
|
||||
@pytest.mark.parametrize("test_model", [
|
||||
TestAllReduceRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [4096])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(not find_spec("flashinfer"),
|
||||
reason="flashinfer is not installed")
|
||||
@pytest.mark.skipif(not current_platform.is_device_capability(100),
|
||||
reason="Only test on SM100")
|
||||
@pytest.mark.skipif(
|
||||
not find_spec("flashinfer")
|
||||
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
||||
reason="flashinfer is not found or flashinfer "
|
||||
"is not compiled with trtllm_allreduce_fusion")
|
||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
num_processes = 2
|
||||
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
||||
and not current_platform.has_device_capability(100)):
|
||||
pytest.skip("Skip as nvfp4 is only supported on "
|
||||
"devices with compute capability 10.0 (Blackwell)")
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
torch.multiprocessing.spawn(fn,
|
||||
@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm"],
|
||||
compile_sizes=[2, 4, 8]))
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"]))
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True)
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
seed=42)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
backend = TestBackend(all_reduce_fusion_pass)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
requires_grad=False)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
||||
requires_grad=False)
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
residual = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states, residual)
|
||||
|
||||
@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
|
||||
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
|
||||
trust_remote_code=True),
|
||||
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
||||
@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
||||
trust_remote_code=True),
|
||||
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
|
||||
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
|
||||
return [(text_generations, text,
|
||||
(None if x.logprobs is None else x.logprobs.top_logprobs))
|
||||
for completion in completions for x in completion.choices]
|
||||
|
||||
|
||||
def has_module_attribute(module_name, attribute_name):
|
||||
"""
|
||||
Helper function to check if a module has a specific attribute.
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
return hasattr(module, attribute_name)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
196
tests/v1/attention/test_chunked_local_attention.py
Normal file
196
tests/v1/attention/test_chunked_local_attention.py
Normal file
@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
make_local_attention_virtual_batches)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalAttentionTestData:
|
||||
# Input parameters
|
||||
batch_spec: BatchSpec
|
||||
attn_chunk_size: int
|
||||
block_size: int
|
||||
# Expected return values
|
||||
expected_q_seqlens: list[int]
|
||||
expected_k_seqlens: list[int]
|
||||
expected_local_block_table: list[list[int]]
|
||||
|
||||
|
||||
test_data_list = [
|
||||
# Same as example in docstring of make_local_attention_virtual_batches
|
||||
# except block table has 9 columns instead of 10
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[4, 10, 5],
|
||||
seq_lens=[6, 17, 9],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
|
||||
expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
|
||||
# 2 pages per local branch
|
||||
# (chunk size 4 // block size 2)
|
||||
expected_local_block_table=[
|
||||
[0, 1], # local-batch 0, (batch 0, starting from k[0])
|
||||
[2, 3], # local-batch 1, (batch 0, starting from k[4])
|
||||
[11, 12], # local-batch 2, (batch 1, starting from k[4])
|
||||
[13, 14], # local-batch 3, (batch 1, starting from k[8])
|
||||
[15, 16], # local-batch 4, (batch 1, starting from k[12])
|
||||
[17, 17], # local-batch 5, (batch 1, starting from k[16])
|
||||
[20, 21], # local-batch 6, (batch 2, starting from k[4])
|
||||
[22, 23], # local-batch 7, (batch 2, starting from k[8])
|
||||
]),
|
||||
# Case where block indices are not clipped to block table ncols-1
|
||||
# because tokens_in_last_block == attn_chunk_size
|
||||
LocalAttentionTestData(batch_spec=BatchSpec(
|
||||
query_lens=[8],
|
||||
seq_lens=[12],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[4, 4],
|
||||
expected_k_seqlens=[4, 4],
|
||||
expected_local_block_table=[
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
]),
|
||||
# Case where all kv_seq positions are involved in attn
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[7],
|
||||
# 10 - 7 = 3 previously computed tokens
|
||||
seq_lens=[10],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[1, 4, 2],
|
||||
expected_k_seqlens=[4, 4, 2],
|
||||
expected_local_block_table=[
|
||||
[0, 1],
|
||||
[2, 3],
|
||||
[4, 4],
|
||||
]),
|
||||
# Case where attn_chunk_size > kv_seq_len
|
||||
# so no extra mini virtual batches are created
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[4],
|
||||
seq_lens=[6],
|
||||
),
|
||||
# Larger than kv_seq_len
|
||||
attn_chunk_size=10,
|
||||
block_size=2,
|
||||
# No change to q_seqlens and k_seqlens
|
||||
expected_q_seqlens=[4],
|
||||
expected_k_seqlens=[6],
|
||||
# In this case, we only need a block-table like:
|
||||
# block_table = [ [0, 1, 2] ] # 1 batch, 3 pages
|
||||
# But we need to pad it to 5 pages per local batch
|
||||
# because currently the pages_per_local_batch
|
||||
# is calculated as (attn_chunk_size // block_size)
|
||||
expected_local_block_table=[
|
||||
[0, 1, 2, 2, 2],
|
||||
]),
|
||||
# Block size equal to chunk size
|
||||
# Expect single page per batch in local batch table
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[6, 6],
|
||||
seq_lens=[8, 8],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=4,
|
||||
expected_q_seqlens=[2, 4, 2, 4],
|
||||
expected_k_seqlens=[4, 4, 4, 4],
|
||||
# Initial block table = [
|
||||
# [0, 1], < batch 0
|
||||
# [2, 3], < batch 1
|
||||
# ]
|
||||
expected_local_block_table=[
|
||||
[0], # local-batch 0, (batch 0, starting from k[0])
|
||||
[1], # local-batch 1, (batch 0, starting from k[4])
|
||||
[2], # local-batch 1, (batch 0, starting from k[0])
|
||||
[3], # local-batch 1, (batch 0, starting from k[4])
|
||||
]),
|
||||
# Case where query falls in the second attention chunk
|
||||
# k_toks > 0 1 2 3 4
|
||||
# q_toks v _____________
|
||||
# 0 | 1
|
||||
# 1 | 1 1
|
||||
# 2 | 1 1 1
|
||||
# 3 | 1 1 1 1
|
||||
# 4 | 1
|
||||
# where tokens 0,1,2,3 have been pre-computed
|
||||
LocalAttentionTestData(batch_spec=BatchSpec(
|
||||
query_lens=[1],
|
||||
seq_lens=[5],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[1],
|
||||
expected_k_seqlens=[1],
|
||||
expected_local_block_table=[
|
||||
[2, 2],
|
||||
]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_data_list)
|
||||
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
|
||||
device = torch.device("cuda:0")
|
||||
batch_spec = test_data.batch_spec
|
||||
attn_chunk_size = test_data.attn_chunk_size
|
||||
block_size = test_data.block_size
|
||||
expected_q_seqlens = test_data.expected_q_seqlens
|
||||
expected_k_seqlens = test_data.expected_k_seqlens
|
||||
expected_local_block_table = test_data.expected_local_block_table
|
||||
|
||||
# Create common attention metadata
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size,
|
||||
device,
|
||||
# Use torch.arange instead of torch.randint so we can assert on
|
||||
# block table tensor values. The block table will have shape
|
||||
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be
|
||||
# aranged from 0 to cdiv(max_seq_len, block_size)-1
|
||||
arange_block_indices=True,
|
||||
)
|
||||
|
||||
# Call the function
|
||||
result = make_local_attention_virtual_batches(attn_chunk_size,
|
||||
common_attn_metadata,
|
||||
block_size)
|
||||
|
||||
# Convert to numpy for easier comparison
|
||||
actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
|
||||
actual_k_seqlens = result.seq_lens_cpu.numpy()
|
||||
|
||||
# Check that all query lengths are less than or equal to attn_chunk_size
|
||||
assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
|
||||
# Check that all key lengths are less than or equal to attn_chunk_size
|
||||
assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens)
|
||||
# Check that the total number of query tokens is preserved
|
||||
assert sum(actual_q_seqlens) == sum(batch_spec.query_lens)
|
||||
|
||||
# Verify results
|
||||
np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens)
|
||||
np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens)
|
||||
|
||||
expected_block_table_tensor =\
|
||||
torch.tensor(expected_local_block_table,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
print(f"Expected block table:\n{expected_block_table_tensor}")
|
||||
print(f"Actual block table:\n{result.block_table_tensor}")
|
||||
|
||||
torch.testing.assert_close(result.block_table_tensor,
|
||||
expected_block_table_tensor)
|
||||
@ -40,7 +40,8 @@ def create_common_attn_metadata(
|
||||
batch_spec: BatchSpec,
|
||||
block_size: int,
|
||||
device: torch.device,
|
||||
max_block_idx: int = 1000) -> CommonAttentionMetadata:
|
||||
max_block_idx: int = 1000,
|
||||
arange_block_indices: bool = False) -> CommonAttentionMetadata:
|
||||
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
|
||||
# Create query start locations
|
||||
query_start_loc = torch.zeros(batch_spec.batch_size + 1,
|
||||
@ -65,19 +66,28 @@ def create_common_attn_metadata(
|
||||
]
|
||||
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
||||
|
||||
# Create block table (random for testing)
|
||||
# Create block table and slot mapping
|
||||
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
|
||||
block_table_tensor = torch.randint(0,
|
||||
max_block_idx,
|
||||
(batch_spec.batch_size, max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# Create slot mapping
|
||||
slot_mapping = torch.randint(0,
|
||||
max_block_idx, (num_tokens, ),
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
if arange_block_indices:
|
||||
num_blocks = batch_spec.batch_size * max_blocks
|
||||
block_table_tensor = torch.arange(num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=device).view(
|
||||
batch_spec.batch_size,
|
||||
max_blocks)
|
||||
slot_mapping = torch.arange(num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device).view(num_tokens)
|
||||
else:
|
||||
block_table_tensor = torch.randint(0,
|
||||
max_block_idx,
|
||||
(batch_spec.batch_size, max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
slot_mapping = torch.randint(0,
|
||||
max_block_idx, (num_tokens, ),
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
|
||||
# Calculate max query length
|
||||
max_query_len = max(batch_spec.query_lens)
|
||||
|
||||
@ -3,29 +3,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_prompts():
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
if mm_enabled:
|
||||
prompt_types.append("mm")
|
||||
num_prompts = 100
|
||||
prompts = []
|
||||
|
||||
random.seed(0)
|
||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
||||
print(f"Prompt types: {random_prompt_type_choices}")
|
||||
|
||||
# Generate a mixed batch of prompts, some of which can be easily
|
||||
# predicted by n-gram matching and some which likely cannot.
|
||||
for kind in random_prompt_type_choices:
|
||||
word_choices = ["test", "temp", "hello", "where"]
|
||||
word = random.choice(word_choices)
|
||||
prompt: Union[str, list[dict[str, Any]]] = ""
|
||||
if kind == "repeat":
|
||||
prompt = f"""
|
||||
please repeat the word '{word}' 10 times.
|
||||
@ -38,6 +43,21 @@ def test_prompts():
|
||||
uses the word {word} at least once.
|
||||
give no other output than that simple sentence without quotes.
|
||||
"""
|
||||
elif kind == "mm":
|
||||
placeholders = [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url":
|
||||
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
|
||||
},
|
||||
}]
|
||||
prompt = [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The meaning of the image is"
|
||||
},
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt type: {kind}")
|
||||
prompts.append([{"role": "user", "content": prompt}])
|
||||
@ -57,7 +77,6 @@ def model_name():
|
||||
|
||||
def test_ngram_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
):
|
||||
@ -67,6 +86,7 @@ def test_ngram_correctness(
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
@ -103,23 +123,32 @@ def test_ngram_correctness(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_setup", [
|
||||
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
|
||||
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
],
|
||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled"], [
|
||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
False,
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
True,
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
],
|
||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
mm_enabled: bool,
|
||||
):
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
|
||||
@ -37,6 +37,8 @@ logger = init_logger(__name__)
|
||||
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class BasePattern:
|
||||
@ -394,7 +396,7 @@ if flashinfer_comm is not None:
|
||||
# Max size of the input tensor per world size
|
||||
# to use flashinfer fused allreduce
|
||||
_FI_MAX_SIZES = {
|
||||
2: MiB, # 1MB
|
||||
2: 64 * MiB, # 64MB
|
||||
4: MiB, # 1MB
|
||||
6: MiB // 2, # 512KB
|
||||
8: MiB // 2, # 512KB
|
||||
@ -414,9 +416,13 @@ if flashinfer_comm is not None:
|
||||
trigger_completion_at_end: bool,
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
fuse_rms_quant: bool,
|
||||
norm_out: Optional[torch.Tensor] = None,
|
||||
quant_out: Optional[torch.Tensor] = None,
|
||||
scale_out: Optional[torch.Tensor] = None,
|
||||
scale_factor: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
|
||||
num_tokens, hidden_size = allreduce_in.shape
|
||||
element_size = allreduce_in.element_size()
|
||||
current_tensor_size = num_tokens * hidden_size * element_size
|
||||
@ -425,7 +431,6 @@ if flashinfer_comm is not None:
|
||||
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
|
||||
max_fusion_size,
|
||||
)
|
||||
|
||||
if use_flashinfer:
|
||||
assert (_FI_WORKSPACE_TENSOR is not None
|
||||
), "Flashinfer must be enabled when using flashinfer"
|
||||
@ -455,37 +460,65 @@ if flashinfer_comm is not None:
|
||||
use_oneshot=True,
|
||||
trigger_completion_at_end=trigger_completion_at_end,
|
||||
fp32_acc=fp32_acc,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNorm,
|
||||
pattern_code=pattern_code,
|
||||
allreduce_out=None,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
layout_code=None,
|
||||
scale_factor=None,
|
||||
quant_out=quant_out,
|
||||
scale_out=scale_out,
|
||||
# in vllm we only support swizzled layout
|
||||
layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
else:
|
||||
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
|
||||
rms_gamma, rms_eps)
|
||||
if (scale_factor is not None and scale_out is None
|
||||
and fuse_rms_quant):
|
||||
# Do fused rms norm static fp8 quant fused op
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||
quant_out, allreduce_out, residual, rms_gamma,
|
||||
scale_factor, rms_eps)
|
||||
else:
|
||||
torch.ops._C.rms_norm_static_fp8_quant(
|
||||
quant_out, allreduce_out, rms_gamma, scale_factor,
|
||||
rms_eps)
|
||||
else:
|
||||
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
|
||||
rms_eps)
|
||||
allreduce_in.copy_(allreduce_out)
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
|
||||
rms_gamma, rms_eps)
|
||||
norm_out = allreduce_out
|
||||
else:
|
||||
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
|
||||
rms_eps)
|
||||
if scale_factor is not None:
|
||||
if scale_out is not None:
|
||||
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
|
||||
scale_out, scale_factor)
|
||||
else:
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
quant_out, norm_out, scale_factor)
|
||||
if scale_factor is None or norm_out is not None:
|
||||
# we need to return allreduce outpput
|
||||
# in cases of non quant fused AR + RMS norm
|
||||
# and fused AR + RMS norm + quant without fused add
|
||||
allreduce_in.copy_(allreduce_out)
|
||||
|
||||
def call_trtllm_fused_allreduce_norm_fake(
|
||||
allreduce_in: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_gamma: torch.Tensor,
|
||||
rms_eps: float,
|
||||
world_rank: int,
|
||||
world_size: int,
|
||||
launch_with_pdl: bool,
|
||||
trigger_completion_at_end: bool,
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
norm_out: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
allreduce_in: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_gamma: torch.Tensor,
|
||||
rms_eps: float,
|
||||
world_rank: int,
|
||||
world_size: int,
|
||||
launch_with_pdl: bool,
|
||||
trigger_completion_at_end: bool,
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
fuse_rms_quant: bool,
|
||||
norm_out: Optional[torch.Tensor] = None,
|
||||
quant_out: Optional[torch.Tensor] = None,
|
||||
scale_out: Optional[torch.Tensor] = None,
|
||||
scale_factor: Optional[torch.Tensor] = None) -> None:
|
||||
pass
|
||||
|
||||
direct_register_custom_op(
|
||||
@ -495,6 +528,8 @@ if flashinfer_comm is not None:
|
||||
"allreduce_in",
|
||||
"residual",
|
||||
"norm_out",
|
||||
"quant_out",
|
||||
"scale_out",
|
||||
],
|
||||
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams:
|
||||
world_size: int,
|
||||
use_fp32_lamport: bool = False,
|
||||
max_token_num: int = 1024,
|
||||
fuse_rms_quant: bool = False,
|
||||
):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams:
|
||||
self.fp32_acc = True
|
||||
self.use_oneshot = False
|
||||
self.max_token_num = max_token_num
|
||||
self.fuse_rms_quant = fuse_rms_quant
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self):
|
||||
return {
|
||||
@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams:
|
||||
"trigger_completion_at_end": self.trigger_completion_at_end,
|
||||
"fp32_acc": self.fp32_acc,
|
||||
"max_token_num": self.max_token_num,
|
||||
"fuse_rms_quant": self.fuse_rms_quant,
|
||||
}
|
||||
|
||||
|
||||
class AllReduceRMSNORMPattern(BasePattern):
|
||||
class AllReduceRMSNormPattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
with fused flashinfer implementation.
|
||||
Applies to allreduce + rmsnorm before attn in the first Transformer block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern):
|
||||
|
||||
def pattern(input: torch.Tensor, rms_result: torch.Tensor,
|
||||
weight: torch.Tensor):
|
||||
all_reduce_output = tensor_model_parallel_all_reduce(input)
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rms_result,
|
||||
input=all_reduce_output,
|
||||
input=allreduce_output,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
return rms[1], all_reduce_output
|
||||
# rms_result, allreduce_output
|
||||
return rms[1], allreduce_output
|
||||
|
||||
def replacement(input: torch.Tensor, rms_result: torch.Tensor,
|
||||
weight: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=rms_result,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNorm,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
# rms_result, allreduce_in
|
||||
return allreduce[3], allreduce[1]
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern):
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (with residual)
|
||||
with fused flashinfer implementation.
|
||||
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
|
||||
def pattern(residual: torch.Tensor, input: torch.Tensor,
|
||||
weight: torch.Tensor):
|
||||
all_reduce_output = tensor_model_parallel_all_reduce(input)
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=all_reduce_output,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
# input, residual
|
||||
return rms[1], rms[2]
|
||||
|
||||
def replacement(residual: torch.Tensor, input: torch.Tensor,
|
||||
weight: torch.Tensor):
|
||||
allreduce = auto_functionalized(
|
||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
norm_out=None,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNorm,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# allreduce_in, residual
|
||||
return allreduce[1], allreduce[2]
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
+ static fp8 quant with fused flashinfer implementation.
|
||||
Applies to allreduce + rmsnorm + quant before attn
|
||||
in the first Transformer block.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
allreduce_params: FlashInferFusedAllReduceParams):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def get_inputs():
|
||||
input = torch.zeros([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
quant_result = torch.empty([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=self.quant_dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rmsnorm_out_tuple = auto_functionalized(RMS_OP,
|
||||
result=rmsnorm_result,
|
||||
input=all_reduce,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
|
||||
result=quant_result,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return quant_out_tuple[1], all_reduce
|
||||
|
||||
def replacement(input: torch.Tensor, result_rms: torch.Tensor,
|
||||
quant_result: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=result_rms,
|
||||
quant_out=quant_result,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
|
||||
scale_factor=scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return allreduce[4], allreduce[1]
|
||||
|
||||
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (with residual)
|
||||
+ static fp8 quant with fused flashinfer implementation.
|
||||
Applies to o_proj + rmsnorm after attn + quant and
|
||||
mlp + rmsnorm + quant before attn.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
allreduce_params: FlashInferFusedAllReduceParams):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def get_inputs():
|
||||
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.quant_dtype)
|
||||
scale = torch.empty([1, 1],
|
||||
device=self.device,
|
||||
dtype=torch.float32)
|
||||
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
weight,
|
||||
scale,
|
||||
]
|
||||
|
||||
def pattern(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
|
||||
fused_add_rmsnorm_out_tuple = \
|
||||
auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP8_QUANT_OP,
|
||||
result=quant_result,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
|
||||
|
||||
def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=quant_result,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
|
||||
scale_factor=scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# # quant_out, rms_norm_residual
|
||||
return allreduce[4], allreduce[2]
|
||||
|
||||
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
+ static nvfp4 quant with fused flashinfer implementation.
|
||||
Applies to allreduce + rmsnorm + quant before attn
|
||||
in the first Transformer block.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
allreduce_params: FlashInferFusedAllReduceParams):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def get_inputs():
|
||||
input = torch.empty([1, 16, 16],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
|
||||
rmsnorm_result = torch.empty([1, 16, 16],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8),
|
||||
device=self.device,
|
||||
dtype=torch.uint8)
|
||||
input_global_scale = torch.empty([1, 1],
|
||||
device=self.device,
|
||||
dtype=torch.float32)
|
||||
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||
output_scale = torch.empty([128, 4],
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
return [
|
||||
input, rmsnorm_result, quant_result, weight,
|
||||
input_global_scale, output_scale
|
||||
]
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rmsnorm_out_tuple = auto_functionalized(RMS_OP,
|
||||
result=rmsnorm_result,
|
||||
input=all_reduce,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
|
||||
|
||||
def replacement(input: torch.Tensor, result_rms: torch.Tensor,
|
||||
quant_result: torch.Tensor, weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=result_rms,
|
||||
quant_out=quant_result,
|
||||
scale_out=output_scale,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
|
||||
scale_factor=input_global_scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return allreduce[4], allreduce[1], allreduce[5]
|
||||
|
||||
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (with residual)
|
||||
+ static nvfp4 quant with fused flashinfer implementation.
|
||||
Applies to o_proj + rmsnorm after attn + quant and
|
||||
mlp + rmsnorm + quant before attn.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
allreduce_params: FlashInferFusedAllReduceParams):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def get_inputs():
|
||||
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([16, 16],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
weight = torch.empty([16, 16],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8),
|
||||
device=self.device,
|
||||
dtype=torch.uint8)
|
||||
input_global_scale = torch.empty([1, 1],
|
||||
device=self.device,
|
||||
dtype=torch.float32)
|
||||
output_scale = torch.empty([128, 4],
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
output_scale,
|
||||
weight,
|
||||
input_global_scale,
|
||||
]
|
||||
|
||||
def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
|
||||
input: torch.Tensor, output_scale: torch.Tensor,
|
||||
weight: torch.Tensor, input_global_scale: torch.Tensor):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
|
||||
fused_add_rmsnorm_out_tuple = \
|
||||
auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
|
||||
2], quant_out_tuple[2]
|
||||
|
||||
def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
|
||||
input: torch.Tensor, output_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor):
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=quant_result,
|
||||
scale_out=output_scale,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||
kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
|
||||
scale_factor=input_global_scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# quant_out, rms_norm_residual, output_scale
|
||||
return allreduce[4], allreduce[2], allreduce[5]
|
||||
|
||||
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllReduceFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
self.tp_size,
|
||||
)
|
||||
return
|
||||
|
||||
max_num_token = min(
|
||||
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
|
||||
(self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
|
||||
config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num)
|
||||
self.ipc_handles, workspace_tensor = (
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||
tp_rank=rank,
|
||||
tp_size=self.tp_size,
|
||||
max_token_num=config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num,
|
||||
max_token_num=max_num_token,
|
||||
hidden_dim=self.hidden_dim,
|
||||
group=self.group,
|
||||
use_fp32_lamport=use_fp32_lamport,
|
||||
@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
rank=rank,
|
||||
world_size=self.tp_size,
|
||||
use_fp32_lamport=use_fp32_lamport,
|
||||
max_token_num=config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num,
|
||||
)
|
||||
max_token_num=max_num_token,
|
||||
# fuse rms norm static fp8 quant fused op
|
||||
# in fallback path, when we don't use flashinfer
|
||||
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
AllReduceRMSNORMPattern(
|
||||
AllReduceFusedRMSNormStaticQuantFP8Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
if current_platform.has_device_capability(100):
|
||||
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceRMSNormPattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||
|
||||
self.disabled = False
|
||||
|
||||
def __call__(self, graph: fx.Graph):
|
||||
@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
if self.disabled:
|
||||
return
|
||||
if flashinfer_comm is not None:
|
||||
flashinfer_comm.trtllm_destroy_ipc_workspace(
|
||||
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
|
||||
self.ipc_handles, self.group)
|
||||
|
||||
@ -108,7 +108,7 @@ def support_torch_compile(
|
||||
During runtime, when we actually mark dimensions of tensors,
|
||||
it depends on the value of arguments:
|
||||
|
||||
- if it is a single integer (can be negative), the corresponding dimension
|
||||
- if it is a single integer (can be negative), the corresponding dimension
|
||||
of the argument will be marked as dynamic.
|
||||
- if it is `None`, ignored.
|
||||
- if it is `IntermediateTensors`, all the tensors in the intermediate
|
||||
|
||||
@ -4062,7 +4062,7 @@ class PassConfig:
|
||||
"""Whether to enable async TP."""
|
||||
enable_fi_allreduce_fusion: bool = False
|
||||
"""Whether to enable flashinfer allreduce fusion."""
|
||||
fi_allreduce_fusion_max_token_num: int = 1024
|
||||
fi_allreduce_fusion_max_token_num: int = 16384
|
||||
"""Max number of tokens to used in flashinfer allreduce fusion."""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||
from .step3_tool_parser import Step3ToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
|
||||
__all__ = [
|
||||
@ -40,4 +41,5 @@ __all__ = [
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
"Qwen3CoderToolParser",
|
||||
"Step3ToolParser",
|
||||
]
|
||||
|
||||
296
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
Normal file
296
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
Normal file
@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["step3"])
|
||||
class Step3ToolParser(ToolParser):
|
||||
"""
|
||||
Tool parser for a model that uses a specific XML-like format for tool calls.
|
||||
This version uses a robust, stateful, cursor-based streaming parser and
|
||||
consolidates tool arguments into a single message.
|
||||
"""
|
||||
|
||||
TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
|
||||
TOOL_CALLS_END = "<|tool_calls_end|>"
|
||||
TOOL_CALL_BEGIN = "<|tool_call_begin|>"
|
||||
TOOL_CALL_END = "<|tool_call_end|>"
|
||||
TOOL_SEP = "<|tool_sep|>"
|
||||
SPECIAL_TOKENS = [
|
||||
TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
|
||||
]
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
# Explicit state flags for robust streaming
|
||||
self.tool_block_started = False
|
||||
self.tool_block_finished = False
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def _parse_steptml_invoke(
|
||||
action_text: str
|
||||
) -> tuple[Optional[str], Optional[dict[str, str]]]:
|
||||
func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
|
||||
action_text)
|
||||
if not func_name_match:
|
||||
return None, None
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
params: dict[str, str] = {}
|
||||
param_matches = re.findall(
|
||||
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
|
||||
action_text)
|
||||
for name, value in param_matches:
|
||||
params[name] = value.strip()
|
||||
return func_name, params
|
||||
|
||||
def _cast_arguments(
|
||||
self,
|
||||
func_name: str,
|
||||
params: dict[str, Any],
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict[str, Any]:
|
||||
for tool in request.tools or []:
|
||||
if tool.function.name == func_name:
|
||||
schema = tool.function.parameters or {}
|
||||
properties = schema.get("properties", {})
|
||||
for key, value in params.items():
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
prop = properties.get(key, {})
|
||||
typ = prop.get("type")
|
||||
if typ == "string":
|
||||
params[key] = value.strip()
|
||||
elif typ == "integer":
|
||||
with contextlib.suppress(ValueError):
|
||||
params[key] = int(value)
|
||||
elif typ == "number":
|
||||
with contextlib.suppress(ValueError):
|
||||
params[key] = float(value)
|
||||
elif typ == "boolean":
|
||||
lower_val = value.lower()
|
||||
params[key] = lower_val == "true" if lower_val in (
|
||||
"true", "false") else value
|
||||
elif typ == "null":
|
||||
params[key] = None if value.lower(
|
||||
) == "null" else value
|
||||
break
|
||||
return params
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# The main loop processes the stream from the last known position.
|
||||
while True:
|
||||
if self.position >= len(current_text):
|
||||
return None # We've processed the entire stream.
|
||||
|
||||
unprocessed_text = current_text[self.position:]
|
||||
|
||||
# STATE: After all tools are done, all subsequent text is content.
|
||||
if self.tool_block_finished:
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=unprocessed_text)
|
||||
|
||||
# STATE: Before the tool block has started.
|
||||
if not self.tool_block_started:
|
||||
if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
|
||||
self.position += len(self.TOOL_CALLS_BEGIN)
|
||||
self.tool_block_started = True
|
||||
continue # Token consumed, re-loop.
|
||||
|
||||
start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
|
||||
if start_pos == -1:
|
||||
if self.TOOL_CALLS_BEGIN.startswith(
|
||||
unprocessed_text.strip()) and unprocessed_text:
|
||||
return None # It's a prefix, wait.
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=unprocessed_text)
|
||||
else:
|
||||
content = unprocessed_text[:start_pos]
|
||||
self.position += len(content)
|
||||
return DeltaMessage(content=content)
|
||||
|
||||
# STATE: Inside the main tool block.
|
||||
offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
|
||||
unprocessed_text = unprocessed_text.lstrip()
|
||||
self.position += offset
|
||||
|
||||
if unprocessed_text.startswith(self.TOOL_CALLS_END):
|
||||
self.position += len(self.TOOL_CALLS_END)
|
||||
self.tool_block_finished = True
|
||||
self.current_tool_id = -1
|
||||
continue
|
||||
|
||||
# Check if we are between tool calls.
|
||||
tool_finished = (
|
||||
self.current_tool_id != -1 and
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("finished"))
|
||||
if self.current_tool_id == -1 or tool_finished:
|
||||
if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
|
||||
self.position += len(self.TOOL_CALL_BEGIN)
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
else:
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id]["finished"] = False
|
||||
continue
|
||||
|
||||
if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
|
||||
return None
|
||||
|
||||
# STATE: Parsing an active tool call.
|
||||
if self.current_tool_id != -1 and not self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("finished", False):
|
||||
end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
|
||||
if end_tool_pos == -1:
|
||||
tool_body = unprocessed_text
|
||||
else:
|
||||
tool_body = unprocessed_text[:end_tool_pos]
|
||||
|
||||
if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
|
||||
tool_body):
|
||||
return None
|
||||
|
||||
function_name, arguments = self._parse_steptml_invoke(
|
||||
tool_body)
|
||||
if not function_name:
|
||||
return None
|
||||
|
||||
tool_call_arr = {
|
||||
"name": function_name,
|
||||
"parameters": arguments or {}
|
||||
}
|
||||
|
||||
# Send the function name as soon as it's parsed.
|
||||
if not self.current_tool_name_sent:
|
||||
self.current_tool_name_sent = True
|
||||
self.prev_tool_call_arr[self.current_tool_id].update(
|
||||
tool_call_arr)
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name))
|
||||
])
|
||||
|
||||
# Update our internal state with the latest parsed arguments.
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id].update( # noqa: E501
|
||||
tool_call_arr)
|
||||
|
||||
# Only send arguments when the tool call is complete.
|
||||
if end_tool_pos != -1:
|
||||
self.position += end_tool_pos + len(self.TOOL_CALL_END)
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id]["finished"] = True
|
||||
|
||||
final_args = self._cast_arguments(
|
||||
function_name,
|
||||
tool_call_arr.get("parameters", {}), # type: ignore
|
||||
request)
|
||||
if final_args:
|
||||
final_args_json = json.dumps(final_args,
|
||||
ensure_ascii=False)
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=final_args_json))
|
||||
])
|
||||
|
||||
# If tool is not finished, return None to wait for more tokens.
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if self.TOOL_CALLS_BEGIN not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
|
||||
if self.TOOL_CALLS_END not in rest:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
|
||||
content = (pre_text + post_text).strip()
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
call_parts = tool_block.split(self.TOOL_CALL_BEGIN)
|
||||
|
||||
for part in call_parts:
|
||||
if not part or self.TOOL_CALL_END not in part:
|
||||
continue
|
||||
|
||||
call_content = part.split(self.TOOL_CALL_END, 1)[0]
|
||||
if self.TOOL_SEP not in call_content:
|
||||
continue
|
||||
|
||||
type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
|
||||
if type_part.strip() != "function":
|
||||
continue
|
||||
|
||||
function_name, params_dict = self._parse_steptml_invoke(
|
||||
invoke_part)
|
||||
|
||||
if function_name and params_dict is not None:
|
||||
params_dict = self._cast_arguments(function_name, params_dict,
|
||||
request)
|
||||
params_str = json.dumps(params_dict, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ToolCall(function=FunctionCall(name=function_name,
|
||||
arguments=params_str)))
|
||||
if tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
||||
stacked_params_mapping = [
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
# QKV fusion is optional, fall back to normal
|
||||
# weight loading if it's not enabled
|
||||
if ((param_name == "fused_qkv_a_proj")
|
||||
and name not in params_dict):
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
|
||||
@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
|
||||
Llama4ForCausalLM)
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
|
||||
self.norm = RMSNorm(self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.fc(
|
||||
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||
torch.cat((inputs_embeds, hidden_states), dim=-1))
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
return self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
model_weights[name] = loaded_weight
|
||||
|
||||
loader.load_weights(model_weights.items())
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_index,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||
)
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||
)
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def compute_logits(
|
||||
|
||||
@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
||||
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
|
||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||
@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = {
|
||||
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
||||
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
|
||||
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
|
||||
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
|
||||
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
||||
|
||||
521
vllm/model_executor/models/step3_text.py
Normal file
521
vllm/model_executor/models/step3_text.py
Normal file
@ -0,0 +1,521 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only Jurassic model."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FusedMoEBlock(nn.Module):
    """Sparse mixture-of-experts block: a replicated router feeding a
    fused expert layer. The cross-rank all-reduce is performed here (not
    inside ``FusedMoE``) so it happens exactly once per forward pass.
    """

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        # Experts are sharded across TP ranks; more ranks than experts
        # cannot be partitioned.
        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}.")

        # reduce_results=False: the reduction is done manually below.
        self.experts = FusedMoE(num_experts=config.moe_num_experts,
                                top_k=config.moe_top_k,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=False,
                                renormalize=config.norm_expert_weight,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts")
        # The router is tiny and is kept replicated and unquantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.moe_num_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through its top-k experts.

        The output has the same shape as the input.
        """
        input_shape = hidden_states.shape
        flat = hidden_states.view(-1, input_shape[-1])

        logits, _ = self.gate(flat)
        routed = self.experts(hidden_states=flat, router_logits=logits)

        if self.tp_size > 1:
            routed = tensor_model_parallel_all_reduce(routed)

        return routed.view(input_shape)
|
||||
|
||||
|
||||
class Step3TextMLP(nn.Module):
    """Gated feed-forward block (SiLU-and-multiply activation)."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel GEMM.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        fused, _ = self.gate_up_proj(hidden_states)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected
|
||||
|
||||
|
||||
class Step3TextAttention(nn.Module):
    """Attention with a compressed shared query projection and a single
    KV head.

    ``qkv_proj`` emits a compressed query of width ``share_q_dim`` plus
    exactly one K and one V head; the query is RMS-normalized
    (``inter_norm``) and then expanded to all heads by ``wq``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are sharded evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # This implementation hard-requires exactly one KV head.
        if num_kv_heads != 1:
            raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
                             f"but got {num_kv_heads}.")
        self.num_kv_heads = num_kv_heads

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Width of the compressed query; falls back to a single head_dim.
        self.q_size = share_q_dim if share_q_dim else self.head_dim

        # Replicated (not sharded): every rank needs the full compressed
        # query and the single K/V head.
        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            self.q_size + self.kv_size * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Normalizes the compressed query before head expansion.
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        # Expands the compressed query to per-head queries (column-sharded).
        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(self.head_dim,
                                   rotary_dim=self.head_dim,
                                   max_position=max_position_embedding,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling)
        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              self.num_kv_heads,
                              cache_config=cache_config,
                              prefix=f"{prefix}.attn")

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention for one layer; returns the o_proj output."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Normalize the compressed query, then expand it to all heads.
        q = self.inter_norm(q)
        q = self.wq(q)[0]
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        residual, _ = self.o_proj(attn_output)
        return residual
|
||||
|
||||
|
||||
class Step3TextDecoderLayer(nn.Module):
    """One decoder layer: attention plus either a dense MLP or an MoE
    block combined with a shared expert, selected per layer index.
    """

    def __init__(self,
                 config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        config = config.hf_config
        self.hidden_size = config.hidden_size
        rope_scaling = getattr(config, "rope_scaling", None)

        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=config.rms_norm_eps,
            max_position_embedding=config.max_position_embedding,
            head_dim=config.head_dim,
            share_q_dim=config.share_q_dim,
            rope_theta=config.rope_theta,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn")

        # The layer index is recovered from the module prefix
        # (e.g. "model.layers.12.") set by make_layers().
        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            # Default to 1dense: first layer dense, the rest MoE.
            moe_layers_idx = list(range(1, config.num_hidden_layers))
        elif isinstance(moe_layers_enum, str):
            # Comma-separated string form, e.g. "4,5,6".
            moe_layers_idx = [
                int(i) for i in moe_layers_enum.strip().split(',')
            ]
        else:
            # Step3TextConfig defaults to a tuple of ints; the previous
            # string-only parsing raised AttributeError on it.
            moe_layers_idx = [int(i) for i in moe_layers_enum]

        if layer_idx in moe_layers_idx:
            # MoE layer: routed experts plus an always-on shared expert.
            self.moe = FusedMoEBlock(config=config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.moe")
            self.share_expert = Step3TextMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert")
            self.use_moe = True
        else:
            self.mlp = Step3TextMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act="silu",
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
            self.use_moe = False
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply one decoder layer with fused residual handling.

        ``residual`` is None only for the very first layer of a stage.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        if self.use_moe:
            # Shared expert runs on every token; routed experts add to it.
            share_output = self.share_expert(hidden_states)
            moe_output = self.moe(hidden_states)
            hidden_states = share_output + moe_output
        else:
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
class Step3TextModel(nn.Module):
    """Pipeline-parallel-aware decoder stack for the Step3 text model.

    Embeddings live on the first PP rank (and the last rank as well when
    word embeddings are tied); the final norm lives on the last rank.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = config.vocab_size
        self.config = config

        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3TextDecoderLayer(config=vllm_config.
                                                 model_config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # forward() passes both "hidden_states" and "residual" between
        # pipeline stages, so the empty-tensor factory must declare both.
        # Declaring only "hidden_states" (as before) makes non-first PP
        # ranks fail on intermediate_tensors["residual"].
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (first/last PP rank only)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of decoder layers.

        Returns final hidden states on the last PP rank, otherwise an
        IntermediateTensors carrying hidden_states and residual.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
||||
|
||||
|
||||
class Step3TextForCausalLM(nn.Module, SupportsPP):
    """Causal-LM head wrapper around Step3TextModel with pipeline-parallel
    support: the LM head, logits processor, and sampler exist only on the
    last PP rank.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config

        self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)

        if get_pp_group().is_last_rank:
            # LoRA may extend the vocabulary; pad accordingly.
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                if not lora_config else lora_config.lora_vocab_padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size)
            self.sampler = get_sampler()
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Delegate to the underlying decoder stack."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits (last PP rank only)."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits (last PP rank only)."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping names for fused layers.

        Handles three remappings, tried in order per weight:
        stacked gate/up projections, per-expert MoE weights, and the
        fused q/k/v projection (copied into a fractional slice of
        qkv_proj along its output dim). Returns the set of parameter
        names that were loaded.
        """
        # q/k/v occupy contiguous fractions of qkv_proj's output dim:
        # [0, q) -> q_proj, [q, q+kv) -> k_proj, [q+kv, q+2kv) -> v_proj,
        # expressed as fractions of (share_q_dim + 2*head_dim).
        qkv_params_mapping = [
            # (param_name, shard_name, relative_start_idx, relative_end_idx)
            (".qkv_proj", ".q_proj", 0, self.config.share_q_dim /
             (self.config.share_q_dim + self.config.head_dim * 2)),
            (".qkv_proj", ".k_proj", self.config.share_q_dim /
             (self.config.share_q_dim + self.config.head_dim * 2),
             (self.config.share_q_dim + self.config.head_dim) /
             (self.config.share_q_dim + self.config.head_dim * 2)),
            (".qkv_proj", ".v_proj",
             (self.config.share_q_dim + self.config.head_dim) /
             (self.config.share_q_dim + self.config.head_dim * 2),
             (self.config.share_q_dim + self.config.head_dim * 2) /
             (self.config.share_q_dim + self.config.head_dim * 2)),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Checkpoint MoE weights map onto FusedMoE's fused w13/w2 params.
        expert_params_mapping = [
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
        ]

        # MoE weights must not be treated as plain stacked gate/up params.
        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # Checkpoint stacks all experts along dim 0; load
                    # them one expert at a time into the fused param.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(param,
                                      loaded_weight_expert,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_idx,
                         end_idx) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        # Convert the fractional bounds into absolute
                        # rows of qkv_proj and copy in place.
                        dim = param.shape[param.output_dim]
                        begin_idx = int(start_idx * dim)
                        end_idx = int(end_idx * dim)
                        param_slice = param.narrow(param.output_dim, begin_idx,
                                                   end_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        # Default path: a plain, unfused parameter.
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
        return loaded_params
|
||||
1052
vllm/model_executor/models/step3_vl.py
Normal file
1052
vllm/model_executor/models/step3_vl.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser
|
||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||
from .mistral_reasoning_parser import MistralReasoningParser
|
||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||
from .step3_reasoning_parser import Step3ReasoningParser
|
||||
|
||||
__all__ = [
|
||||
"ReasoningParser",
|
||||
@ -18,4 +19,5 @@ __all__ = [
|
||||
"Qwen3ReasoningParser",
|
||||
"Glm4MoeModelReasoningParser",
|
||||
"MistralReasoningParser",
|
||||
"Step3ReasoningParser",
|
||||
]
|
||||
|
||||
109
vllm/reasoning/step3_reasoning_parser.py
Normal file
109
vllm/reasoning/step3_reasoning_parser.py
Normal file
@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Step3 model.

    Step3 marks the end of its reasoning with a ``</think>`` token:
    everything before the first ``</think>`` is reasoning content and
    everything after it is the final answer.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
                                          re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Split a streaming delta into reasoning vs. answer content.

        previous + delta == current. Token IDs are used to locate the
        ``</think>`` boundary quickly. For "abc</think>xyz": 'abc' is
        reasoning_content and 'xyz' is content.
        """
        end_id = self.think_end_token_id

        # A delta that is exactly the end marker carries no content.
        if delta_token_ids == [end_id] or (len(delta_token_ids) == 1
                                           and delta_token_ids[0] == end_id):
            return None

        if end_id in delta_token_ids:
            # Boundary falls inside this delta: split the text around it.
            marker_pos = delta_text.find(self.think_end_token)
            before = delta_text[:marker_pos]
            after = delta_text[marker_pos + len(self.think_end_token):]
            return DeltaMessage(reasoning_content=before,
                                content=after if after else None)

        if end_id in previous_token_ids:
            # Boundary already passed: the whole delta is answer content.
            return DeltaMessage(content=delta_text)

        # Boundary not reached yet: the whole delta is reasoning.
        return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Split a complete output at the first ``</think>``.

        Returns (reasoning_content, content); content is None when the
        marker is absent or nothing follows it.
        """
        reasoning, marker, answer = model_output.partition(
            self.think_end_token)
        if not marker:
            # No </think> at all: the entire output is reasoning.
            return model_output, None
        return reasoning, (answer if answer else None)

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """True once the ``</think>`` token has been generated."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids after ``</think>`` (excluding a marker
        appearing only as the very last token)."""
        head = input_ids[:-1]
        if self.think_end_token_id in head:
            return input_ids[head.index(self.think_end_token_id) + 1:]
        return []
|
||||
@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
|
||||
MllamaConfig, MLPSpeculatorConfig,
|
||||
Nemotron_Nano_VL_Config,
|
||||
NemotronConfig, NVLM_D_Config,
|
||||
RWConfig, UltravoxConfig)
|
||||
RWConfig, Step3TextConfig,
|
||||
Step3VLConfig, UltravoxConfig)
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
|
||||
"nemotron": NemotronConfig,
|
||||
"NVLM_D": NVLM_D_Config,
|
||||
"ultravox": UltravoxConfig,
|
||||
"step3_vl": Step3VLConfig,
|
||||
"step3_text": Step3TextConfig,
|
||||
**_CONFIG_REGISTRY_OVERRIDE_HF
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
||||
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
||||
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
||||
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
|
||||
Step3VisionEncoderConfig,
|
||||
Step3VLConfig)
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
|
||||
__all__ = [
|
||||
@ -42,4 +45,7 @@ __all__ = [
|
||||
"Nemotron_Nano_VL_Config",
|
||||
"NVLM_D_Config",
|
||||
"UltravoxConfig",
|
||||
"Step3VLConfig",
|
||||
"Step3VisionEncoderConfig",
|
||||
"Step3TextConfig",
|
||||
]
|
||||
|
||||
123
vllm/transformers_utils/configs/step3_vl.py
Normal file
123
vllm/transformers_utils/configs/step3_vl.py
Normal file
@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Step3VisionEncoderConfig(PretrainedConfig):
    """Configuration for the Step3 vision encoder tower."""

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Input image handling.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        # Numerics / activation.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3TextConfig(PretrainedConfig):
    """Configuration for the Step3 text decoder.

    ``moe_layers_enum`` lists the layer indices that use the MoE block;
    by default every layer from 4 through 59.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        moe_layers_enum: tuple[int,
                               ...] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                       15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
                                       25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
                                       35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
                                       45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
                                       55, 56, 57, 58, 59),
        **kwargs,
    ) -> None:
        # Dense transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum
        # Positional encoding.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding
        # Shared-expert / compressed-query dimensions.
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3VLConfig(PretrainedConfig):
    """Composite vision-language configuration for Step3-VL.

    Sub-configs may be given as dicts (from JSON) or ready-made config
    objects; missing ones fall back to the sub-config defaults.
    """

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Coerce the vision sub-config to an object.
        if isinstance(vision_config, dict):
            self.vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = Step3VisionEncoderConfig()
        else:
            self.vision_config = vision_config

        # Coerce the text sub-config to an object.
        if isinstance(text_config, dict):
            self.text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            self.text_config = Step3TextConfig()
        else:
            self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text hidden size at the top level for convenience.
        self.hidden_size = self.text_config.hidden_size
        self.image_token_id = image_token_id

        super().__init__(**kwargs)
|
||||
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -51,6 +53,9 @@ class EagleProposer:
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
|
||||
self.is_multimodal_model = vllm_config.model_config \
|
||||
.is_multimodal_model
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE and
|
||||
not self.vllm_config.model_config.enforce_eager)
|
||||
@ -76,6 +81,11 @@ class EagleProposer:
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
@ -88,6 +98,7 @@ class EagleProposer:
|
||||
next_token_ids: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embeds: Optional[list[torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
@ -128,14 +139,27 @@ class EagleProposer:
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
if self.is_multimodal_model:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds or None,
|
||||
)
|
||||
self.inputs_embeds[:num_tokens] = inputs_embeds
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
ret_hidden_states = self.model(
|
||||
self.input_ids[:num_input_tokens],
|
||||
self.positions[:num_input_tokens],
|
||||
self.hidden_states[:num_input_tokens],
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "deepseek_mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
@ -218,15 +242,24 @@ class EagleProposer:
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
if self.is_multimodal_model:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
self.inputs_embeds[:batch_size] = inputs_embeds
|
||||
inputs_embeds = self.inputs_embeds[:input_batch_size]
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
input_ids = self.input_ids[:input_batch_size]
|
||||
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
self.input_ids[:input_batch_size],
|
||||
self.positions[:input_batch_size],
|
||||
self.hidden_states[:input_batch_size],
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:input_batch_size],
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
@ -391,10 +424,18 @@ class EagleProposer:
|
||||
) -> None:
|
||||
with set_forward_context(None, self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
self.model(
|
||||
self.input_ids[:num_tokens],
|
||||
self.positions[:num_tokens],
|
||||
self.hidden_states[:num_tokens],
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
def validate_same_kv_cache_group(self,
|
||||
|
||||
@ -1314,13 +1314,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def _gather_mm_embeddings(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
shift_computed_tokens: int = 0,
|
||||
) -> list[torch.Tensor]:
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
for req_id in self.input_batch.req_ids:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
num_computed_tokens = \
|
||||
req_state.num_computed_tokens + shift_computed_tokens
|
||||
mm_positions = req_state.mm_positions
|
||||
for i, pos_info in enumerate(mm_positions):
|
||||
start_pos = pos_info.offset
|
||||
@ -2298,6 +2300,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
mm_embeds = None
|
||||
if self.is_multimodal_model:
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output,
|
||||
shift_computed_tokens=1)
|
||||
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
@ -2305,6 +2312,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
next_token_ids=next_token_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
mm_embeds=mm_embeds,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user