Merge branch 'main' into mlm-full-lora-support
This commit is contained in:
commit 421707dec1
@@ -15,6 +15,21 @@ steps:
     env:
       DOCKER_BUILDKIT: "1"
 
+  - label: "Build arm64 wheel - CUDA 13.0"
+    depends_on: ~
+    id: build-wheel-arm64-cuda-13-0
+    agents:
+      queue: arm64_cpu_queue_postmerge
+    commands:
+      # NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
+      # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
+      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
+      - "mkdir artifacts"
+      - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
+      - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
+    env:
+      DOCKER_BUILDKIT: "1"
+
   # aarch64 build
   - label: "Build arm64 CPU wheel"
     depends_on: ~
@@ -25,7 +40,7 @@ steps:
       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
       - "mkdir artifacts"
      - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
-      - "bash .buildkite/scripts/upload-wheels.sh"
+      - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
     env:
       DOCKER_BUILDKIT: "1"
 
@@ -39,7 +54,7 @@ steps:
       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
       - "mkdir artifacts"
       - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
-      - "bash .buildkite/scripts/upload-wheels.sh"
+      - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31"
     env:
       DOCKER_BUILDKIT: "1"
 
@@ -52,7 +67,7 @@ steps:
       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
       - "mkdir artifacts"
       - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
-      - "bash .buildkite/scripts/upload-wheels.sh"
+      - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
     env:
       DOCKER_BUILDKIT: "1"
 
@@ -372,6 +372,17 @@ if __name__ == "__main__":
 
     print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
 
+    # keep only "official" files for a non-nightly version (specified by cli args)
+    PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")
+    if PY_VERSION_RE.match(version):
+        # upload-wheels.sh ensures no "dev" is in args.version
+        wheel_files = list(
+            filter(lambda x: version in x and "dev" not in x, wheel_files)
+        )
+        print(f"Non-nightly version detected, wheel files used: {wheel_files}")
+    else:
+        print("Nightly version detected, keeping all wheel files.")
+
     # Generate index and metadata, assuming wheels and indices are stored as:
     # s3://vllm-wheels/{version}/<wheel files>
     # s3://vllm-wheels/<anything>/<index files>
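As an aside on the version gate above: the regex admits plain releases and suffixed builds, and a `.devN` suffix would also match, which is why the hunk pairs it with an explicit `"dev" not in x` filter. A minimal, self-contained illustration (not part of the diff):

```python
import re

# Same pattern as the hunk above.
PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")

for version in ["0.12.0", "0.12.0rc1", "0.12.0.dev5", "nightly", "1.2"]:
    matched = bool(PY_VERSION_RE.match(version))
    print(f"{version}: {'release-style' if matched else 'nightly/other'}")
# 0.12.0 and 0.12.0rc1 match; 0.12.0.dev5 matches the regex too, which is
# why the script additionally rejects file names containing "dev";
# nightly and 1.2 do not match.
```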
@@ -34,9 +34,10 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then
 fi
 wheel="${wheel_files[0]}"
 
-# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31
+# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31
+# we also accept a manylinux tag as a parameter
 # refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels
-manylinux_version="manylinux_2_31"
+manylinux_version="${1:-manylinux_2_31}"
 
 # Rename 'linux' to the appropriate manylinux version in the wheel filename
 if [[ "$wheel" != *"linux"* ]]; then
@@ -96,8 +97,11 @@ if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]];
     aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/"
 fi
 
-# copy to /<pure_version>/ only if it does not have "dev" in the version
+# re-generate and copy to /<pure_version>/ only if it does not have "dev" in the version
 if [[ "$version" != *"dev"* ]]; then
-    echo "Uploading indices to overwrite /$pure_version/"
+    echo "Re-generating indices for /$pure_version/"
+    rm -rf "$INDICES_OUTPUT_DIR"/*
+    mkdir -p "$INDICES_OUTPUT_DIR"
+    $PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
     aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
 fi
 
@@ -326,10 +326,10 @@ steps:
   commands:
     - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
 
-- label: V1 Test e2e + engine # 30min
-  timeout_in_minutes: 45
+- label: V1 Test e2e + engine # 65min
+  timeout_in_minutes: 90
   mirror_hardwares: [amdexperimental]
-  agent_pool: mi325_1
+  agent_pool: mi325_4
   # grade: Blocking
   source_file_dependencies:
   - vllm/
@@ -435,7 +435,7 @@ steps:
 
 - label: Examples Test # 30min
   timeout_in_minutes: 45
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
   working_dir: "/vllm-workspace/examples"
@@ -455,7 +455,6 @@ steps:
     # for multi-modal models
     - python3 offline_inference/audio_language.py --seed 0
     - python3 offline_inference/vision_language.py --seed 0
-    - python3 offline_inference/vision_language_pooling.py --seed 0
     - python3 offline_inference/vision_language_multi_image.py --seed 0
     - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
     # for pooling models
@@ -99,7 +99,6 @@ def benchmark_mrope(
     # the parameters to compute the q k v size based on tp_size
     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
         max_position=max_position,
         is_neox_style=is_neox_style,
         rope_parameters=rope_parameters,
@@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
     def benchmark(batch_size, seq_len, num_heads, provider):
         dtype = torch.bfloat16
         max_position = 8192
-        base = 10000
-        rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+        rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
+        rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
         rope = rope.to(dtype=dtype, device=device)
         cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
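For readers tracking the API change: the benchmark hunks above move `rotary_dim` and `base` out of the positional `get_rope(...)` arguments and into a `rope_parameters` dict. A hedged sketch of the call-site migration, using only names and keys that appear in this diff (not a definitive reference for the new signature):

```python
head_size, rotary_dim, max_position, base = 128, 64, 8192, 10000
is_neox_style = True

# Old call shape (removed in this diff):
#   rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)

# New call shape (added in this diff): the rotary dimension becomes a ratio,
# and the theta/base moves into the rope_parameters dict. The two call sites
# in the diff use different subsets of these keys.
rope_parameters = {
    "rope_type": "default",
    "rope_theta": base,
    "partial_rotary_factor": rotary_dim / head_size,
}
# rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
```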
@@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
   run_python(_VLLM_TORCH_GOMP_PATH
     "
 import os, glob
-try:
-    import torch
-    torch_pkg = os.path.dirname(torch.__file__)
-    site_root = os.path.dirname(torch_pkg)
-    torch_libs = os.path.join(site_root, 'torch.libs')
-    print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
-except:
-    print('')
+import torch
+torch_pkg = os.path.dirname(torch.__file__)
+site_root = os.path.dirname(torch_pkg)
+
+# Search both torch.libs and torch/lib
+roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
+candidates = []
+for root in roots:
+    if not os.path.isdir(root):
+        continue
+    candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
+
+print(candidates[0] if candidates else '')
 "
-    "failed to probe torch.libs for libgomp")
+    "failed to probe for libgomp")
 
   if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
     return()
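The probe embedded in the CMake function above can be exercised outside CMake. A standalone version of the same search logic, assuming only that torch is installed (it mirrors the diff's inline script rather than adding anything new):

```python
import glob
import os

import torch

# Replicates the probe above: look for a libgomp bundled alongside the
# installed torch package, checking both torch.libs and torch/lib.
torch_pkg = os.path.dirname(torch.__file__)
site_root = os.path.dirname(torch_pkg)

roots = [os.path.join(site_root, "torch.libs"), os.path.join(torch_pkg, "lib")]
candidates = []
for root in roots:
    if not os.path.isdir(root):
        continue
    candidates.extend(glob.glob(os.path.join(root, "libgomp*.so*")))

# Empty output means no shim is needed; CMake treats that as "return early".
print(candidates[0] if candidates else "")
```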
@@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
       largest = value;
     }
   }
-
-  __syncwarp();  // Ensure all threads have valid data before reduction
   // Get the top2 warpwise
   T max1 = cg::reduce(tile, largest, cg::greater<T>());
@@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
   int pre_count_equal_to_top_value = 0;
   // Use loop to find the largest top_group
   while (count_equal_to_top_value < target_num_min) {
-    __syncwarp();  // Ensure all threads have valid data before reduction
     topk_group_value = cg::reduce(tile, value, cg::greater<T>());
     if (value == topk_group_value) {
       value = neg_inf<T>();
@@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
     }
   }
   queue.done();
-  __syncwarp();
   // Get the topk_idx
   queue.dumpIdx(s_topk_idx);
-  __syncwarp();
 }
 
 // Load the valid score value
@@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
 }
 }
@@ -100,7 +100,23 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
 # --8<-- [end:build-wheel-from-source]
 # --8<-- [start:pre-built-images]
 
-Currently, there are no pre-built Arm CPU images.
-See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
+Stable vLLM Docker images are pre-built for Arm starting with version 0.12.0. Available image tags are listed at [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
+Replace `<version>` in the command below with a specific version string (e.g., `0.12.0`).
+
+```bash
+docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v<version>
+```
+
+You can also pull Docker images built from the latest main-branch code. These are meant for CI and testing only, are not intended for production use, and expire after several days.
+
+The latest code can contain bugs and may not be stable. Please use it with caution.
+
+```bash
+export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch
+docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu
+```
+
 # --8<-- [end:pre-built-images]
 # --8<-- [start:build-image-from-source]
@@ -568,7 +568,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
     ```
 
 !!! note
-    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker.py](../../examples/pooling/score/qwen3_reranker.py).
+    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
 
     ```bash
     vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
@@ -851,7 +851,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin
 [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
 popular open-source tools.
 
-Code example: [examples/pooling/score/jinaai_rerank_client.py](../../examples/pooling/score/jinaai_rerank_client.py)
+Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py)
 
 #### Example Request
 
@@ -4,6 +4,9 @@
 from argparse import Namespace
 
 from vllm import LLM, EngineArgs
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.config import AttentionConfig
+from vllm.platforms import current_platform
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
 
@@ -20,6 +23,11 @@ def parse_args():
 
 
 def main(args: Namespace):
+    if current_platform.is_rocm():
+        args.attention_config = AttentionConfig(
+            backend=AttentionBackendEnum.FLEX_ATTENTION
+        )
+
     # Sample prompts.
     prompts = [
         "Hello, my name is",
@@ -4,6 +4,9 @@
 from argparse import Namespace
 
 from vllm import LLM, EngineArgs
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.config import AttentionConfig
+from vllm.platforms import current_platform
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
 
@@ -20,6 +23,11 @@ def parse_args():
 
 
 def main(args: Namespace):
+    if current_platform.is_rocm():
+        args.attention_config = AttentionConfig(
+            backend=AttentionBackendEnum.FLEX_ATTENTION
+        )
+
     # Sample prompts.
     text_1 = "What is the capital of France?"
     texts_2 = [
@@ -33,6 +33,7 @@ import os
 from time import sleep
 
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
 from vllm.utils.network_utils import get_open_port
 
 
@@ -222,6 +223,11 @@ if __name__ == "__main__":
 
     from multiprocessing import Process
 
+    if current_platform.is_rocm():
+        from multiprocessing import set_start_method
+
+        set_start_method("spawn", force=True)
+
     procs = []
     for local_dp_rank, global_dp_rank in enumerate(
         range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
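Context for the `spawn` change above: CUDA/HIP runtimes generally cannot survive a `fork` once the GPU has been initialized in the parent, so the ROCm path forces the spawn start method before creating worker processes. A minimal, hedged sketch of the pattern (standalone, not vLLM code):

```python
from multiprocessing import Process, set_start_method


def worker(rank: int) -> None:
    # Any GPU initialization would happen here, safely after spawn.
    print(f"worker {rank} started")


if __name__ == "__main__":
    # force=True overrides a start method that may already have been set.
    set_start_method("spawn", force=True)
    procs = [Process(target=worker, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```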
@@ -21,7 +21,7 @@
 # --worker \
 #   /abs/path/to/huggingface/cache \
 #   -e VLLM_HOST_IP=<worker_node_ip>
 #
 # Each worker requires a unique VLLM_HOST_IP value.
 # Keep each terminal session open. Closing a session stops the associated Ray
 # node and thereby shuts down the entire cluster.
@@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
     exit 1
 fi
 
+# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=...").
+VLLM_HOST_IP=""
+for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do
+    arg="${ADDITIONAL_ARGS[$i]}"
+    case "${arg}" in
+    -e)
+        next="${ADDITIONAL_ARGS[$((i + 1))]:-}"
+        if [[ "${next}" == VLLM_HOST_IP=* ]]; then
+            VLLM_HOST_IP="${next#VLLM_HOST_IP=}"
+            break
+        fi
+        ;;
+    -eVLLM_HOST_IP=* | VLLM_HOST_IP=*)
+        VLLM_HOST_IP="${arg#*=}"
+        break
+        ;;
+    esac
+done
+
+# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent.
+if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then
+    if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then
+        echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})."
+        echo "Using VLLM_HOST_IP as the head node address."
+        HEAD_NODE_ADDRESS="${VLLM_HOST_IP}"
+    fi
+fi
+
 # Generate a unique container name with random suffix.
 # Docker container names must be unique on each host.
 # The random suffix allows multiple Ray containers to run simultaneously on the same machine,
@@ -74,36 +102,17 @@ cleanup() {
 trap cleanup EXIT
 
 # Build the Ray start command based on the node role.
 # The head node manages the cluster and accepts connections on port 6379,
 # while workers connect to the head's address.
 RAY_START_CMD="ray start --block"
 if [ "${NODE_TYPE}" == "--head" ]; then
-    RAY_START_CMD+=" --head --port=6379"
+    RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379"
 else
+    if [ -n "${VLLM_HOST_IP}" ]; then
+        RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}"
+    fi
     RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
 fi
 
-# Parse VLLM_HOST_IP from additional args if present.
-# This is needed for multi-NIC configurations where Ray needs explicit IP bindings.
-VLLM_HOST_IP=""
-for arg in "${ADDITIONAL_ARGS[@]}"; do
-    if [[ $arg == "-e" ]]; then
-        continue
-    fi
-    if [[ $arg == VLLM_HOST_IP=* ]]; then
-        VLLM_HOST_IP="${arg#VLLM_HOST_IP=}"
-        break
-    fi
-done
-
-# Build Ray IP environment variables if VLLM_HOST_IP is set.
-# These variables ensure Ray binds to the correct network interface on multi-NIC systems.
-RAY_IP_VARS=()
-if [ -n "${VLLM_HOST_IP}" ]; then
-    RAY_IP_VARS=(
-        -e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
-        -e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
-    )
-fi
-
 # Launch the container with the assembled parameters.
@@ -118,6 +127,5 @@ docker run \
     --shm-size 10.24g \
     --gpus all \
     -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
-    "${RAY_IP_VARS[@]}" \
     "${ADDITIONAL_ARGS[@]}" \
     "${DOCKER_IMAGE}" -c "${RAY_START_CMD}"
@@ -50,4 +50,5 @@ ijson # Required for mistral streaming tool parser
 setproctitle # Used to set process names for better debugging and monitoring
 openai-harmony >= 0.0.3 # Required for gpt-oss
+anthropic == 0.71.0
 model-hosting-container-standards >= 0.1.9, < 1.0.0
 mcp
@@ -1,2 +1,2 @@
-lmcache >= 0.3.10.post1
+lmcache
 nixl >= 0.7.1 # Required for disaggregated prefill
@@ -138,6 +138,17 @@ elif current_platform.is_rocm():
     CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
 
 
+def has_cuda_graph_wrapper_metadata() -> bool:
+    from importlib import import_module
+
+    try:
+        module = import_module("torch._inductor.utils")
+        module.CUDAGraphWrapperMetadata  # noqa: B018
+    except AttributeError:
+        return False
+    return True
+
+
 @pytest.mark.parametrize(
     "model_name, model_kwargs, backend, matches, custom_ops",
     # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
@@ -145,7 +156,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
     # quant_fp4 only has the custom impl
     + list(flat_product(MODELS_FP4, [""])),
 )
-@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+@pytest.mark.parametrize(
+    "inductor_graph_partition",
+    [
+        pytest.param(
+            True,
+            marks=pytest.mark.skipif(
+                not has_cuda_graph_wrapper_metadata(),
+                reason="This test requires "
+                "torch._inductor.utils.CUDAGraphWrapperMetadata to run",
+            ),
+        ),
+        False,
+    ],
+)
 def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
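The probe-plus-`pytest.param(..., marks=pytest.mark.skipif(...))` combination above is a reusable pattern for gating a single parametrized case on an optional upstream feature. A generic, hedged illustration with a stand-in feature check (`math.sumprod`, available on Python 3.12+); this is not vLLM code:

```python
import pytest


def has_attr(module_name: str, attr: str) -> bool:
    # Hypothetical helper mirroring has_cuda_graph_wrapper_metadata above:
    # import lazily and treat a missing attribute as "feature absent".
    from importlib import import_module

    try:
        return hasattr(import_module(module_name), attr)
    except ImportError:
        return False


@pytest.mark.parametrize(
    "mode",
    [
        pytest.param(
            "new",
            marks=pytest.mark.skipif(
                not has_attr("math", "sumprod"),  # stand-in feature check
                reason="requires math.sumprod (Python 3.12+)",
            ),
        ),
        "old",
    ],
)
def test_modes(mode):
    assert mode in ("new", "old")
```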
@@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
 
 
 class TestRotaryEmbedding(torch.nn.Module):
-    def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
+    def __init__(self, head_dim=64, max_position=2048, base=10000):
         super().__init__()
         self.head_dim = head_dim
-        self.rotary_dim = rotary_dim or head_dim
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.rotary_dim,
             max_position=max_position,
             rope_parameters={"rope_type": "default", "rope_theta": base},
         )
@@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
             max_position=max_position,
             rope_parameters={"rope_type": "default", "rope_theta": base},
         )
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py (new file, 276 lines)
@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4

from dataclasses import dataclass

import pytest
import torch

from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    get_dp_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptNvFp4Config,
    ModelOptNvFp4FusedMoE,
)

from .eplb_utils import distributed_run, set_env_vars_and_device


@dataclass
class TestConfig:
    num_layers: int
    num_experts: int
    num_local_experts: int
    num_topk: int
    hidden_size: int
    intermediate_size: int
    num_tokens: int


def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    quant_config = None

    device = torch.device(f"cuda:{rank}")

    quant_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )

    fml = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=torch.bfloat16,
        quant_config=quant_config,
    )

    nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
    nvfp4_fused_moe.create_weights(
        fml,
        test_config.num_local_experts,
        test_config.hidden_size,
        test_config.intermediate_size,
        params_dtype=torch.uint8,
        global_num_experts=test_config.num_experts,
    )

    fml = fml.to(device)
    w1_q, w2_q, quant_config = make_test_quant_config(
        test_config.num_local_experts,
        test_config.intermediate_size,
        test_config.hidden_size,
        in_dtype=torch.bfloat16,
        quant_dtype="nvfp4",
        block_shape=None,
        per_act_token_quant=False,
    )

    fml.w13_weight.data = w1_q
    fml.w2_weight.data = w2_q

    fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
    fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
    fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
    fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
    fml.w2_weight_scale.data = (
        torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
    ).to(fml.w2_weight_scale.data.dtype)
    fml.w13_weight_scale.data = (
        torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
    ).to(fml.w13_weight_scale.data.dtype)

    nvfp4_fused_moe.process_weights_after_loading(fml)

    fml.maybe_init_modular_kernel()

    return fml


def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    set_env_vars_and_device(env)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
        )

        ep_group = get_dp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        device = torch.device(f"cuda:{ep_rank}")

        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
            for layer_idx in range(test_config.num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        hidden_states = []
        router_logits = []
        for layer_idx in range(test_config.num_layers):
            hidden_states.append(
                torch.randn(
                    (test_config.num_tokens, test_config.hidden_size),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )
            router_logits.append(
                torch.randn(
                    (test_config.num_tokens, test_config.num_experts),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )

        out_before_shuffle = []
        with set_forward_context(
            {},
            num_tokens=test_config.num_tokens,
            num_tokens_across_dp=torch.tensor(
                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
            ),
            vllm_config=vllm_config,
        ):
            for lidx, fml in enumerate(fml_layers):
                out_before_shuffle.append(
                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                )

        indices = torch.zeros(
            test_config.num_layers, test_config.num_experts, dtype=torch.long
        )
        for lidx in range(test_config.num_layers):
            indices[lidx] = torch.Tensor(range(test_config.num_experts))

        shuffled_indices = torch.zeros_like(indices)
        for lidx in range(test_config.num_layers):
            shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        num_global_experts = test_config.num_experts

        logical_to_physical_map_list = []
        for lidx, fml in enumerate(fml_layers):
            physical_to_logical_map = shuffled_indices[lidx].to(device)
            logical_to_physical_map = torch.empty(
                (num_global_experts,), dtype=torch.int32, device=device
            )
            logical_to_physical_map[physical_to_logical_map] = torch.arange(
                0, num_global_experts, dtype=torch.int32, device=device
            )
            logical_to_physical_map_list.append(
                logical_to_physical_map.reshape(num_global_experts, 1)
            )

        logical_to_physical_map = torch.stack(logical_to_physical_map_list)

        for lidx, fml in enumerate(fml_layers):
            logical_replica_count = torch.ones(
                (test_config.num_layers, num_global_experts),
                dtype=torch.int32,
                device=device,
            )
            fml.enable_eplb = True
            fml.set_eplb_state(
                lidx,
                torch.zeros(
                    (test_config.num_layers, num_global_experts),
                    dtype=torch.int32,
                    device=device,
                ),
                logical_to_physical_map,
                logical_replica_count,
            )

        out_after_shuffle = []
        with set_forward_context(
            {},
            num_tokens=test_config.num_tokens,
            num_tokens_across_dp=torch.tensor(
                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
            ),
            vllm_config=vllm_config,
        ):
            for lidx, fml in enumerate(fml_layers):
                out_after_shuffle.append(
                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                )

        for lidx in range(test_config.num_layers):
            torch.testing.assert_close(
                out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
            )


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    num_tokens: int,
    backend: str,
    monkeypatch,
):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    num_local_experts = num_experts // world_size
    num_topk = 4

    test_config = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        num_local_experts=num_local_experts,
        num_topk=num_topk,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_tokens=num_tokens,
    )

    distributed_run(
        _test_eplb_fml,
        world_size,
        test_config,
    )
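The index bookkeeping in `_test_eplb_fml` hinges on inverting a permutation with advanced indexing (`logical_to_physical_map[physical_to_logical_map] = arange(...)`). A tiny standalone illustration of that one step (not part of the test):

```python
import torch

# Toy example: physical slot -> logical expert, as in shuffled_indices above.
physical_to_logical = torch.tensor([2, 0, 3, 1])

# Invert it: logical expert -> physical slot.
logical_to_physical = torch.empty_like(physical_to_logical)
logical_to_physical[physical_to_logical] = torch.arange(4)
print(logical_to_physical)  # tensor([1, 3, 0, 2])
# e.g. physical slot 0 holds logical expert 2, so logical_to_physical[2] == 0.
```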
@@ -1,21 +1,37 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import pytest
 from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
 from openai.types.responses.response_output_item import McpCall
 from openai_harmony import Author, Message, Role, TextContent
 
+from tests.entrypoints.openai.utils import verify_harmony_messages
 from vllm.entrypoints.openai.parser.harmony_utils import (
+    auto_drop_analysis_messages,
     get_encoding,
     has_custom_tools,
+    parse_chat_input_to_harmony_message,
     parse_chat_output,
     parse_input_to_harmony_message,
     parse_output_message,
 )
 
 
-class TestParseInputToHarmonyMessage:
-    """Tests for parse_input_to_harmony_message function."""
+class TestCommonParseInputToHarmonyMessage:
+    """
+    Tests for scenarios that are common to both the Chat Completion
+    parse_chat_input_to_harmony_message and the Responses API
+    parse_input_to_harmony_message functions.
+    """
 
-    def test_assistant_message_with_tool_calls(self):
+    @pytest.fixture(
+        params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
+    )
+    def parse_function(self, request):
+        return request.param
+
+    def test_assistant_message_with_tool_calls(self, parse_function):
         """Test parsing assistant message with tool calls."""
         chat_msg = {
             "role": "assistant",
@@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
             ],
         }
 
-        messages = parse_input_to_harmony_message(chat_msg)
+        messages = parse_function(chat_msg)
 
         assert len(messages) == 2
 
@@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
         assert messages[1].recipient == "functions.search_web"
         assert messages[1].content_type == "json"
 
-    def test_assistant_message_with_empty_tool_call_arguments(self):
+    def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
         """Test parsing assistant message with tool call having None arguments."""
         chat_msg = {
             "role": "assistant",
@@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
             ],
         }
 
-        messages = parse_input_to_harmony_message(chat_msg)
+        messages = parse_function(chat_msg)
 
         assert len(messages) == 1
         assert messages[0].content[0].text == ""
         assert messages[0].recipient == "functions.get_current_time"
 
+    def test_system_message(self, parse_function):
+        """Test parsing system message."""
+        chat_msg = {
+            "role": "system",
+            "content": "You are a helpful assistant",
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        # System messages are converted using Message.from_dict
+        # which should preserve the role
+        assert messages[0].author.role == Role.SYSTEM
+
+    def test_developer_message(self, parse_function):
+        """Test parsing developer message."""
+        chat_msg = {
+            "role": "developer",
+            "content": "Use concise language",
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].author.role == Role.DEVELOPER
+
+    def test_user_message_with_string_content(self, parse_function):
+        """Test parsing user message with string content."""
+        chat_msg = {
+            "role": "user",
+            "content": "What's the weather in San Francisco?",
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].author.role == Role.USER
+        assert messages[0].content[0].text == "What's the weather in San Francisco?"
+
+    def test_user_message_with_array_content(self, parse_function):
+        """Test parsing user message with array content."""
+        chat_msg = {
+            "role": "user",
+            "content": [
+                {"text": "What's in this image? "},
+                {"text": "Please describe it."},
+            ],
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].author.role == Role.USER
+        assert len(messages[0].content) == 2
+        assert messages[0].content[0].text == "What's in this image? "
+        assert messages[0].content[1].text == "Please describe it."
+
+    def test_assistant_message_with_string_content(self, parse_function):
+        """Test parsing assistant message with string content (no tool calls)."""
+        chat_msg = {
+            "role": "assistant",
+            "content": "Hello! How can I help you today?",
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].author.role == Role.ASSISTANT
+        assert messages[0].content[0].text == "Hello! How can I help you today?"
+
+    def test_pydantic_model_input(self, parse_function):
+        """Test parsing Pydantic model input (has model_dump method)."""
+
+        class MockPydanticModel:
+            def model_dump(self, exclude_none=True):
+                return {
+                    "role": "user",
+                    "content": "Test message",
+                }
+
+        chat_msg = MockPydanticModel()
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].author.role == Role.USER
+        assert messages[0].content[0].text == "Test message"
+
+    def test_tool_call_with_missing_function_fields(self, parse_function):
+        """Test parsing tool call with missing name or arguments."""
+        chat_msg = {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "function": {}  # Missing both name and arguments
+                }
+            ],
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].recipient == "functions."
+        assert messages[0].content[0].text == ""
+
+    def test_array_content_with_missing_text(self, parse_function):
+        """Test parsing array content where text field is missing."""
+        chat_msg = {
+            "role": "user",
+            "content": [
+                {},  # Missing text field
+                {"text": "actual text"},
+            ],
+        }
+
+        messages = parse_function(chat_msg)
+
+        assert len(messages) == 1
+        assert len(messages[0].content) == 2
+        assert messages[0].content[0].text == ""
+        assert messages[0].content[1].text == "actual text"
+
+
+class TestParseInputToHarmonyMessage:
+    """
+    Tests for scenarios that are specific to the Responses API
+    parse_input_to_harmony_message function.
+    """
+
+    def test_message_with_empty_content(self):
+        """Test parsing message with empty string content."""
+        chat_msg = {
+            "role": "user",
+            "content": "",
+        }
+
+        messages = parse_input_to_harmony_message(chat_msg)
+
+        assert len(messages) == 1
+        assert messages[0].content[0].text == ""
+
     def test_tool_message_with_string_content(self):
         """Test parsing tool message with string content."""
         chat_msg = {
@@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
 
         assert len(messages) == 1
         assert messages[0].author.role == Role.TOOL
+        assert messages[0].author.name == "functions.search_results"
         assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
 
     def test_tool_message_with_empty_content(self):
@@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
         messages = parse_input_to_harmony_message(chat_msg)
 
         assert len(messages) == 1
         assert messages[0].author.role == Role.TOOL
+        assert messages[0].author.name == "functions.empty_tool"
         assert messages[0].content[0].text == ""
 
-    def test_system_message(self):
-        """Test parsing system message."""
-        chat_msg = {
-            "role": "system",
-            "content": "You are a helpful assistant",
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        # System messages are converted using Message.from_dict
-        # which should preserve the role
-        assert messages[0].author.role == Role.SYSTEM
-
-    def test_developer_message(self):
-        """Test parsing developer message."""
-        chat_msg = {
-            "role": "developer",
-            "content": "Use concise language",
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].author.role == Role.DEVELOPER
-
-    def test_user_message_with_string_content(self):
-        """Test parsing user message with string content."""
-        chat_msg = {
-            "role": "user",
-            "content": "What's the weather in San Francisco?",
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].author.role == Role.USER
-        assert messages[0].content[0].text == "What's the weather in San Francisco?"
-
-    def test_user_message_with_array_content(self):
-        """Test parsing user message with array content."""
-        chat_msg = {
-            "role": "user",
-            "content": [
-                {"text": "What's in this image? "},
-                {"text": "Please describe it."},
-            ],
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].author.role == Role.USER
-        assert len(messages[0].content) == 2
-        assert messages[0].content[0].text == "What's in this image? "
-        assert messages[0].content[1].text == "Please describe it."
-
-    def test_assistant_message_with_string_content(self):
-        """Test parsing assistant message with string content (no tool calls)."""
-        chat_msg = {
-            "role": "assistant",
-            "content": "Hello! How can I help you today?",
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].author.role == Role.ASSISTANT
-        assert messages[0].content[0].text == "Hello! How can I help you today?"
-
-    def test_pydantic_model_input(self):
-        """Test parsing Pydantic model input (has model_dump method)."""
-
-        class MockPydanticModel:
-            def model_dump(self, exclude_none=True):
-                return {
-                    "role": "user",
-                    "content": "Test message",
-                }
-
-        chat_msg = MockPydanticModel()
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].author.role == Role.USER
-        assert messages[0].content[0].text == "Test message"
-
-    def test_message_with_empty_content(self):
-        """Test parsing message with empty string content."""
-        chat_msg = {
-            "role": "user",
-            "content": "",
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].content[0].text == ""
-
-    def test_tool_call_with_missing_function_fields(self):
-        """Test parsing tool call with missing name or arguments."""
-        chat_msg = {
-            "role": "assistant",
-            "tool_calls": [
-                {
-                    "function": {}  # Missing both name and arguments
-                }
-            ],
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert messages[0].recipient == "functions."
-        assert messages[0].content[0].text == ""
-
-    def test_array_content_with_missing_text(self):
-        """Test parsing array content where text field is missing."""
-        chat_msg = {
-            "role": "user",
-            "content": [
-                {},  # Missing text field
-                {"text": "actual text"},
-            ],
-        }
-
-        messages = parse_input_to_harmony_message(chat_msg)
-
-        assert len(messages) == 1
-        assert len(messages[0].content) == 2
-        assert messages[0].content[0].text == ""
-        assert messages[0].content[1].text == "actual text"
+
+
+class TestParseChatInputToHarmonyMessage:
+    """
+    Tests for scenarios that are specific to the Chat Completion API
+    parse_chat_input_to_harmony_message function.
+    """
+
+    def test_user_message_with_empty_content(self):
+        chat_msg = {
+            "role": "user",
+            "content": "",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "user",
+                    "content": "",
+                },
+            ],
+        )
+
+    def test_user_message_with_none_content(self):
+        chat_msg = {
+            "role": "user",
+            "content": None,
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "user",
+                    "content": "",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_empty_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "content": "",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        assert len(messages) == 0
+
+    def test_assistant_message_with_none_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "content": None,
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        assert len(messages) == 0
+
+    def test_assistant_message_with_content_but_empty_reasoning(self):
+        chat_msg = {
+            "role": "assistant",
+            "content": "The answer is 4.",
+            "reasoning": "",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "final",
+                    "content": "The answer is 4.",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_reasoning_but_empty_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "reasoning": "I'm thinking about the user's question.",
+            "content": "",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "analysis",
+                    "content": "I'm thinking about the user's question.",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_reasoning_but_none_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "reasoning": "I'm thinking about the user's question.",
+            "content": None,
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "analysis",
+                    "content": "I'm thinking about the user's question.",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_tool_calls_but_no_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"location": "San Francisco"}',
+                    }
+                }
+            ],
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "commentary",
+                    "recipient": "functions.get_weather",
+                    "content": '{"location": "San Francisco"}',
+                    "content_type": "json",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_tool_calls_and_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"location": "San Francisco"}',
+                    }
+                }
+            ],
+            "content": "I'll call the tool.",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "commentary",
+                    "content": "I'll call the tool.",
+                },
+                {
+                    "role": "assistant",
+                    "channel": "commentary",
+                    "recipient": "functions.get_weather",
+                    "content": '{"location": "San Francisco"}',
+                    "content_type": "json",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_tool_calls_and_reasoning(self):
+        chat_msg = {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"location": "San Francisco"}',
+                    }
+                }
+            ],
+            "reasoning": "I should use the get_weather tool.",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "analysis",
+                    "content": "I should use the get_weather tool.",
+                },
+                {
+                    "role": "assistant",
+                    "channel": "commentary",
+                    "recipient": "functions.get_weather",
+                    "content": '{"location": "San Francisco"}',
+                    "content_type": "json",
+                },
+            ],
+        )
+
+    def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
+        chat_msg = {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"location": "San Francisco"}',
+                    }
+                }
+            ],
+            "reasoning": "I should use the get_weather tool.",
+            "content": "I'll call the tool.",
+        }
+
+        messages = parse_chat_input_to_harmony_message(chat_msg)
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "assistant",
+                    "channel": "commentary",
+                    "content": "I'll call the tool.",
+                },
+                {
+                    "role": "assistant",
+                    "channel": "analysis",
+                    "content": "I should use the get_weather tool.",
+                },
+                {
+                    "role": "assistant",
+                    "channel": "commentary",
+                    "recipient": "functions.get_weather",
+                    "content": '{"location": "San Francisco"}',
+                    "content_type": "json",
+                },
+            ],
+        )
+
+    def test_tool_message_with_string_content(self):
+        tool_id_names = {
+            "call_123": "get_weather",
+        }
+        chat_msg = {
+            "role": "tool",
+            "tool_call_id": "call_123",
+            "content": "The weather in San Francisco is sunny, 72°F",
+        }
+
+        messages = parse_chat_input_to_harmony_message(
+            chat_msg, tool_id_names=tool_id_names
+        )
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "tool",
+                    "name": "functions.get_weather",
+                    "content": "The weather in San Francisco is sunny, 72°F",
+                    "channel": "commentary",
+                },
+            ],
+        )
+
+    def test_tool_message_with_array_content(self):
+        tool_id_names = {
+            "call_123": "search_results",
+        }
+        chat_msg = {
+            "role": "tool",
+            "tool_call_id": "call_123",
+            "content": [
+                {"type": "text", "text": "Result 1: "},
+                {"type": "text", "text": "Result 2: "},
+                {
+                    "type": "image",
+                    "url": "http://example.com/img.png",
+                },  # Should be ignored
+                {"type": "text", "text": "Result 3"},
+            ],
+        }
+
+        messages = parse_chat_input_to_harmony_message(
+            chat_msg, tool_id_names=tool_id_names
+        )
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "tool",
+                    "name": "functions.search_results",
+                    "content": "Result 1: Result 2: Result 3",
+                    "channel": "commentary",
+                },
+            ],
+        )
+
+    def test_tool_message_with_empty_content(self):
+        tool_id_names = {
+            "call_123": "empty_tool",
+        }
+        chat_msg = {
+            "role": "tool",
+            "tool_call_id": "call_123",
+            "content": "",
+        }
+
+        messages = parse_chat_input_to_harmony_message(
+            chat_msg, tool_id_names=tool_id_names
+        )
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "tool",
+                    "name": "functions.empty_tool",
+                    "content": "",
+                    "channel": "commentary",
+                },
+            ],
+        )
+
+    def test_tool_message_with_none_content(self):
+        tool_id_names = {
+            "call_123": "empty_tool",
+        }
+        chat_msg = {
+            "role": "tool",
+            "tool_call_id": "call_123",
+            "content": None,
+        }
+
+        messages = parse_chat_input_to_harmony_message(
+            chat_msg, tool_id_names=tool_id_names
+        )
+
+        verify_harmony_messages(
+            messages,
+            [
+                {
+                    "role": "tool",
+                    "name": "functions.empty_tool",
+                    "content": "",
+                    "channel": "commentary",
+                },
+            ],
+        )
+
+
+class TestAutoDropAnalysisMessages:
+    def test_no_analysis_messages(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 4."
+            ).with_channel("final"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        assert cleaned_messages == messages
+
+    def test_only_analysis_message(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking about the user's question."
+            ).with_channel("analysis"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        assert cleaned_messages == messages
+
+    def test_multiple_analysis_messages_without_final_message(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking about the user's question."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking more."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking even more."
+            ).with_channel("analysis"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        assert cleaned_messages == messages
+
+    def test_only_final_message(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 4."
+            ).with_channel("final"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        assert cleaned_messages == messages
+
+    def test_drops_one_analysis_messages_before_final_message(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking about the user's question."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 4."
+            ).with_channel("final"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I should think harder."
+            ).with_channel("analysis"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        # Should have dropped the first analysis message
+        assert cleaned_messages == messages[1:]
+
+    def test_drops_all_analysis_messages_before_final_message(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking about the user's question."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking more."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking even more."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 4."
+            ).with_channel("final"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I should think harder."
+            ).with_channel("analysis"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        # Should have dropped the first 3 analysis messages
+        assert cleaned_messages == messages[3:]
+
+    def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking about the user's question."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking more."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I'm thinking even more."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 4."
+            ).with_channel("final"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "I should think harder."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 5."
+            ).with_channel("final"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        # Should have dropped all those analysis messages
+        assert len(cleaned_messages) == 2
+        assert cleaned_messages[0].content[0].text == "The answer is 4."
+        assert cleaned_messages[1].content[0].text == "The answer is 5."
+
+    def test_drops_non_assistant_analysis_messages(self) -> None:
+        messages = [
+            Message.from_role_and_content(
+                Role.TOOL, "The tool thinks we should think harder."
+            ).with_channel("analysis"),
+            Message.from_role_and_content(
+                Role.ASSISTANT, "The answer is 4."
+            ).with_channel("final"),
+        ]
+        cleaned_messages = auto_drop_analysis_messages(messages)
+        # Should have dropped the analysis message
+        assert cleaned_messages == messages[1:]
+
+
+class TestParseChatOutput:
+    def test_parse_chat_output_interrupted_first_message(self) -> None:
+        harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning is None
+        assert final_content == "I'm in the middle of answering"
+
+    def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
+        harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning == "I'm in the middle of thinking"
+        assert final_content is None
+
+    def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
+        harmony_str = (
+            "<|channel|>analysis<|message|>I'm thinking.<|end|>"
+            "<|start|>assistant<|channel|>final"
+            "<|message|>I'm in the middle of answering"
+        )
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning == "I'm thinking."
+        assert final_content == "I'm in the middle of answering"
+
+    def test_parse_chat_output_complete_content(self) -> None:
+        harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning is None
+        assert final_content == "The answer is 4."
+
+    def test_parse_chat_output_complete_commentary(self) -> None:
+        harmony_str = (
+            "<|channel|>commentary<|message|>I need to call some tools.<|end|>"
+        )
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning is None
+        assert final_content == "I need to call some tools."
+
+    def test_parse_chat_output_complete_reasoning(self) -> None:
+        harmony_str = (
+            "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
+        )
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning == "I've thought hard about this."
+        assert final_content is None
+
+    def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
+        harmony_str = (
+            "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
+            "<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
+        )
+        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
+        reasoning, final_content, _ = parse_chat_output(token_ids)
+        assert reasoning == "I've thought hard about this."
+        assert final_content == "The answer is 4."
 
 
 class TestParseOutputMessage:
@ -11,13 +11,25 @@ import pytest_asyncio
from openai import OpenAI

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

from ...utils import RemoteOpenAIServer
from .utils import (
    accumulate_streaming_response,
    verify_chat_response,
    verify_harmony_messages,
)

GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
    # Verify that data_parallel_rank defaults to None
    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None


class TestServingChatWithHarmony:
    """
    These tests ensure Chat Completion requests are being properly converted into
    Harmony messages and Harmony response messages back into Chat Completion responses.
    These tests are not exhaustive, but each one was created to cover a specific case
    that we got wrong but is now fixed.

    Any changes to the tests and their expectations may result in changes to the
    accuracy of model prompting and responses generated. It is suggested to run
    an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
    any impact of changes in how we prompt Harmony models.
    """

    @pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
    def stream(self, request) -> bool:
        """Parameterize tests to run in both non-streaming and streaming modes."""
        return request.param

    @pytest.fixture()
    def mock_engine(self) -> AsyncLLM:
        mock_engine = MagicMock(spec=AsyncLLM)
        mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        mock_engine.errored = False
        mock_engine.model_config = MockModelConfig()
        mock_engine.input_processor = MagicMock()
        mock_engine.io_processor = MagicMock()
        return mock_engine

    @pytest.fixture()
    def serving_chat(self, mock_engine) -> OpenAIServingChat:
        chat = _build_serving_chat(mock_engine)
        chat.use_harmony = True
        chat.tool_parser = ToolParserManager.get_tool_parser("openai")
        return chat
    def mock_request_output_from_req_and_token_ids(
        self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
    ) -> RequestOutput:
        # Our tests don't use most fields, so just get the token ids correct
        completion_output = CompletionOutput(
            index=0,
            text="",
            token_ids=token_ids,
            cumulative_logprob=0.0,
            logprobs=None,
        )
        return RequestOutput(
            request_id=req.request_id,
            prompt=[],
            prompt_token_ids=[],
            prompt_logprobs=None,
            outputs=[completion_output],
            finished=finished,
        )

    @pytest.fixture
    def weather_tools(self) -> list[dict[str, Any]]:
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string"},
                        },
                        "required": ["location"],
                    },
                },
            },
        ]

    @pytest.fixture
    def weather_messages_start(self) -> list[dict[str, Any]]:
        return [
            {
                "role": "user",
                "content": "What's the weather like in Paris today?",
            },
        ]
    async def generate_response_from_harmony_str(
        self,
        serving_chat: OpenAIServingChat,
        req: ChatCompletionRequest,
        harmony_str: str,
        stream: bool = False,
    ) -> ChatCompletionResponse:
        harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")

        async def result_generator():
            if stream:
                for token_id in harmony_token_ids:
                    yield self.mock_request_output_from_req_and_token_ids(
                        req, [token_id]
                    )
                yield self.mock_request_output_from_req_and_token_ids(
                    req, [], finished=True
                )
            else:
                yield self.mock_request_output_from_req_and_token_ids(
                    req, harmony_token_ids, finished=True
                )

        generator_func = (
            serving_chat.chat_completion_stream_generator
            if stream
            else serving_chat.chat_completion_full_generator
        )

        result = generator_func(
            request=req,
            result_generator=result_generator(),
            request_id=req.request_id,
            model_name=req.model,
            conversation=[],
            tokenizer=get_tokenizer(req.model),
            request_metadata=RequestResponseMetadata(
                request_id=req.request_id,
                model_name=req.model,
            ),
        )

        if stream:
            return await accumulate_streaming_response(result)
        return await result
    @pytest.mark.asyncio
    async def test_simple_chat(self, serving_chat, stream):
        messages = [{"role": "user", "content": "what is 1+1?"}]

        # Test the Harmony messages for the first turn's input
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user", "content": messages[0]["content"]},
            ],
        )

        # Test the Chat Completion response for the first turn's output
        reasoning_str = "We need to think really hard about this."
        final_str = "The answer is 2."
        response_str = (
            f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
            f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
        )
        response = await self.generate_response_from_harmony_str(
            serving_chat, req, response_str, stream=stream
        )
        verify_chat_response(response, content=final_str, reasoning=reasoning_str)

        # Add the output messages from the first turn as input to the second turn
        for choice in response.choices:
            messages.append(choice.message.model_dump(exclude_none=True))

        # Test the Harmony messages for the second turn's input
        req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
        verify_harmony_messages(
            input_messages_2,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user"},
                # The analysis message should be dropped on subsequent inputs because
                # of the subsequent assistant message to the final channel.
                {"role": "assistant", "channel": "final", "content": final_str},
            ],
        )
    @pytest.mark.asyncio
    async def test_tool_call_response_with_content(
        self, serving_chat, stream, weather_tools, weather_messages_start
    ):
        tools = weather_tools
        messages = list(weather_messages_start)

        # Test the Harmony messages for the first turn's input
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer", "tool_definitions": ["get_weather"]},
                {"role": "user", "content": messages[0]["content"]},
            ],
        )

        # Test the Chat Completion response for the first turn's output
        commentary_str = "We'll call get_weather."
        tool_args_str = '{"location": "Paris"}'
        response_str = (
            f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
            "<|start|>assistant to=functions.get_weather<|channel|>commentary"
            f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
        )
        response = await self.generate_response_from_harmony_str(
            serving_chat, req, response_str, stream=stream
        )
        verify_chat_response(
            response,
            content=commentary_str,
            tool_calls=[("get_weather", tool_args_str)],
        )

        tool_call = response.choices[0].message.tool_calls[0]

        # Add the output messages from the first turn as input to the second turn
        for choice in response.choices:
            messages.append(choice.message.model_dump(exclude_none=True))

        # Add our tool output message
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": "20 degrees Celsius",
            },
        )

        # Test the Harmony messages for the second turn's input
        req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
        verify_harmony_messages(
            input_messages_2,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user"},
                {
                    "role": "assistant",
                    "channel": "commentary",
                    "content": commentary_str,
                },
                {
                    "role": "assistant",
                    "channel": "commentary",
                    "recipient": "functions.get_weather",
                    "content": tool_args_str,
                },
                {
                    "role": "tool",
                    "author_name": "functions.get_weather",
                    "channel": "commentary",
                    "recipient": "assistant",
                    "content": "20 degrees Celsius",
                },
            ],
        )
    @pytest.mark.asyncio
    async def test_tools_and_reasoning(
        self, serving_chat, stream, weather_tools, weather_messages_start
    ):
        tools = weather_tools
        messages = list(weather_messages_start)

        # Test the Harmony messages for the first turn's input
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer", "tool_definitions": ["get_weather"]},
                {"role": "user", "content": messages[0]["content"]},
            ],
        )

        # Test the Chat Completion response for the first turn's output
        reasoning_str = "I'll call get_weather."
        tool_args_str = '{"location": "Paris"}'
        response_str = (
            f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
            "<|start|>assistant to=functions.get_weather<|channel|>commentary"
            f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
        )
        response = await self.generate_response_from_harmony_str(
            serving_chat, req, response_str, stream=stream
        )
        verify_chat_response(
            response,
            reasoning=reasoning_str,
            tool_calls=[("get_weather", tool_args_str)],
        )

        tool_call = response.choices[0].message.tool_calls[0]

        # Add the output messages from the first turn as input to the second turn
        for choice in response.choices:
            messages.append(choice.message.model_dump(exclude_none=True))

        # Add our tool output message
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": "20 degrees Celsius",
            },
        )

        # Test the Harmony messages for the second turn's input
        req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
        verify_harmony_messages(
            input_messages_2,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user"},
                {
                    "role": "assistant",
                    "channel": "analysis",
                    "content": reasoning_str,
                },
                {
                    "role": "assistant",
                    "channel": "commentary",
                    "recipient": "functions.get_weather",
                    "content": tool_args_str,
                },
                {
                    "role": "tool",
                    "author_name": "functions.get_weather",
                    "channel": "commentary",
                    "recipient": "assistant",
                    "content": "20 degrees Celsius",
                },
            ],
        )
    @pytest.mark.asyncio
    async def test_multi_turn_tools_and_reasoning(
        self, serving_chat, stream, weather_tools, weather_messages_start
    ):
        tools = weather_tools
        messages = list(weather_messages_start)

        # Test the Harmony messages for the first turn's input
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer", "tool_definitions": ["get_weather"]},
                {"role": "user", "content": messages[0]["content"]},
            ],
        )

        # Test the Chat Completion response for the first turn's output
        reasoning_str = "I'll call get_weather."
        paris_tool_args_str = '{"location": "Paris"}'
        response_str = (
            f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
            "<|start|>assistant to=functions.get_weather<|channel|>commentary"
            f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
        )
        response = await self.generate_response_from_harmony_str(
            serving_chat, req, response_str, stream=stream
        )
        verify_chat_response(
            response,
            reasoning=reasoning_str,
            tool_calls=[("get_weather", paris_tool_args_str)],
        )

        tool_call = response.choices[0].message.tool_calls[0]

        # Add the output messages from the first turn as input to the second turn
        for choice in response.choices:
            messages.append(choice.message.model_dump(exclude_none=True))

        # Add our tool output message
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": "20 degrees Celsius",
            },
        )

        # Test the Harmony messages for the second turn's input
        req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
        verify_harmony_messages(
            input_messages_2,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user"},
                {
                    "role": "assistant",
                    "channel": "analysis",
                    "content": reasoning_str,
                },
                {
                    "role": "assistant",
                    "channel": "commentary",
                    "recipient": "functions.get_weather",
                    "content": paris_tool_args_str,
                },
                {
                    "role": "tool",
                    "author_name": "functions.get_weather",
                    "channel": "commentary",
                    "recipient": "assistant",
                    "content": "20 degrees Celsius",
                },
            ],
        )

        # Test the Chat Completion response for the second turn's output
        paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
        response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
        response_2 = await self.generate_response_from_harmony_str(
            serving_chat, req_2, response_str, stream=stream
        )
        verify_chat_response(response_2, content=paris_weather_str)

        # Add the output messages from the second turn as input to the third turn
        for choice in response_2.choices:
            messages.append(choice.message.model_dump(exclude_none=True))

        # Add a new user message for the third turn
        messages.append(
            {
                "role": "user",
                "content": "What's the weather like in Boston today?",
            },
        )

        # Test the Harmony messages for the third turn's input
        req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3)
        verify_harmony_messages(
            input_messages_3,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user"},
                {
                    "role": "assistant",
                    "channel": "commentary",
                    "recipient": "functions.get_weather",
                    "content": paris_tool_args_str,
                },
                {
                    "role": "tool",
                    "author_name": "functions.get_weather",
                    "channel": "commentary",
                    "recipient": "assistant",
                    "content": "20 degrees Celsius",
                },
                {
                    "role": "assistant",
                    "channel": "final",
                    "content": paris_weather_str,
                },
                {"role": "user", "content": messages[-1]["content"]},
            ],
        )

        # Test the Chat Completion response for the third turn's output
        reasoning_str = "I'll call get_weather."
        boston_tool_args_str = '{"location": "Boston"}'
        response_str = (
            f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
            "<|start|>assistant to=functions.get_weather<|channel|>commentary"
            f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
        )
        response_3 = await self.generate_response_from_harmony_str(
            serving_chat, req_3, response_str, stream=stream
        )
        verify_chat_response(
            response_3,
            reasoning=reasoning_str,
            tool_calls=[("get_weather", boston_tool_args_str)],
        )

        tool_call = response_3.choices[0].message.tool_calls[0]

        # Add the output messages from the third turn as input to the fourth turn
        for choice in response_3.choices:
            messages.append(choice.message.model_dump(exclude_none=True))

        # Add our tool output message
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": "10 degrees Celsius",
            },
        )

        # Test the Harmony messages for the fourth turn's input
        req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4)
        verify_harmony_messages(
            input_messages_4,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user"},
                {"role": "assistant"},
                {"role": "tool"},
                {
                    "role": "assistant",
                    "channel": "final",
                },
                {"role": "user"},
                {
                    "role": "assistant",
                    "channel": "analysis",
                    "content": reasoning_str,
                },
                {
                    "role": "assistant",
                    "channel": "commentary",
                    "recipient": "functions.get_weather",
                    "content": boston_tool_args_str,
                },
                {
                    "role": "tool",
                    "author_name": "functions.get_weather",
                    "channel": "commentary",
                    "recipient": "assistant",
                    "content": "10 degrees Celsius",
                },
            ],
        )
    @pytest.mark.asyncio
    async def test_non_tool_reasoning(self, serving_chat):
        messages: list[dict[str, Any]] = [
            {
                "role": "user",
                "content": "What's 2+2?",
            },
            {
                "role": "assistant",
                "reasoning": "Adding 2 and 2 is easy. The result is 4.",
                "content": "4",
            },
        ]
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)

        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user", "content": messages[0]["content"]},
                # The reasoning that would have resulted in an analysis message is
                # dropped because of a later assistant message to the final channel.
                {
                    "role": "assistant",
                    "channel": "final",
                    "content": messages[1]["content"],
                },
            ],
        )
    @pytest.mark.asyncio
    async def test_non_tool_reasoning_empty_content(self, serving_chat):
        messages: list[dict[str, Any]] = [
            {
                "role": "user",
                "content": "What's 2+2?",
            },
            {
                "role": "assistant",
                "reasoning": "Adding 2 and 2 is easy. The result is 4.",
                "content": "",
            },
        ]
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)

        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user", "content": messages[0]["content"]},
                {
                    "role": "assistant",
                    "channel": "analysis",
                    "content": messages[1]["reasoning"],
                },
            ],
        )
    @pytest.mark.asyncio
    async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
        messages: list[dict[str, Any]] = [
            {
                "role": "user",
                "content": "What's 2+2?",
            },
            {
                "role": "assistant",
                "reasoning": "Adding 2 and 2 is easy. The result is 4.",
                "content": [],
            },
        ]
        req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
        input_messages, _, _ = serving_chat._make_request_with_harmony(req)

        verify_harmony_messages(
            input_messages,
            [
                {"role": "system"},
                {"role": "developer"},
                {"role": "user", "content": messages[0]["content"]},
                {
                    "role": "assistant",
                    "channel": "analysis",
                    "content": messages[1]["reasoning"],
                },
            ],
        )
190 tests/entrypoints/openai/utils.py Normal file
@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator
from typing import Any

from vllm.entrypoints.openai.protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    UsageInfo,
)


async def accumulate_streaming_response(
    stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
    """
    Accumulate streaming SSE chunks into a complete ChatCompletionResponse.

    This helper parses the SSE format and builds up the complete response
    by combining all the delta chunks.
    """
    accumulated_content = ""
    accumulated_reasoning = None
    accumulated_tool_calls: list[dict[str, Any]] = []
    role = None
    finish_reason = None
    response_id = None
    created = None
    model = None
    index = 0
    async for chunk_str in stream_generator:
        # Skip empty lines and [DONE] marker
        if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
            continue

        # Parse SSE format: "data: {json}\n\n"
        if chunk_str.startswith("data: "):
            json_str = chunk_str[6:].strip()
            try:
                chunk_data = json.loads(json_str)
                chunk = ChatCompletionStreamResponse(**chunk_data)

                # Store metadata from first chunk
                if response_id is None:
                    response_id = chunk.id
                    created = chunk.created
                    model = chunk.model

                # Process each choice in the chunk
                for choice in chunk.choices:
                    if choice.delta.role:
                        role = choice.delta.role
                    if choice.delta.content:
                        accumulated_content += choice.delta.content
                    if choice.delta.reasoning:
                        if accumulated_reasoning is None:
                            accumulated_reasoning = ""
                        accumulated_reasoning += choice.delta.reasoning
                    if choice.delta.tool_calls:
                        # Accumulate tool calls
                        for tool_call_delta in choice.delta.tool_calls:
                            # Find or create the tool call at this index
                            while len(accumulated_tool_calls) <= tool_call_delta.index:
                                accumulated_tool_calls.append(
                                    {
                                        "id": None,
                                        "type": "function",
                                        "function": {"name": "", "arguments": ""},
                                    }
                                )

                            if tool_call_delta.id:
                                accumulated_tool_calls[tool_call_delta.index]["id"] = (
                                    tool_call_delta.id
                                )
                            if tool_call_delta.function:
                                if tool_call_delta.function.name:
                                    accumulated_tool_calls[tool_call_delta.index][
                                        "function"
                                    ]["name"] += tool_call_delta.function.name
                                if tool_call_delta.function.arguments:
                                    accumulated_tool_calls[tool_call_delta.index][
                                        "function"
                                    ]["arguments"] += tool_call_delta.function.arguments

                    if choice.finish_reason:
                        finish_reason = choice.finish_reason
                    if choice.index is not None:
                        index = choice.index

            except json.JSONDecodeError:
                continue
    # Build the final message
    message_kwargs = {
        "role": role or "assistant",
        "content": accumulated_content if accumulated_content else None,
        "reasoning": accumulated_reasoning,
    }

    # Only include tool_calls if there are any
    if accumulated_tool_calls:
        message_kwargs["tool_calls"] = [
            {"id": tc["id"], "type": tc["type"], "function": tc["function"]}
            for tc in accumulated_tool_calls
        ]

    message = ChatMessage(**message_kwargs)

    # Build the final response
    choice = ChatCompletionResponseChoice(
        index=index,
        message=message,
        finish_reason=finish_reason or "stop",
    )

    # Create usage info (with dummy values for tests)
    usage = UsageInfo(
        prompt_tokens=0,
        completion_tokens=0,
        total_tokens=0,
    )

    response = ChatCompletionResponse(
        id=response_id or "chatcmpl-test",
        object="chat.completion",
        created=created or 0,
        model=model or "test-model",
        choices=[choice],
        usage=usage,
    )

    return response

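# Illustrative usage (hypothetical, not from the diff): feed SSE-formatted
# chunk strings through the helper and assert on the accumulated result. The
# chunk payload below is a minimal, made-up example; real chunks carry more
# fields.
#
#     async def sse_chunks():
#         yield (
#             'data: {"id": "chatcmpl-1", "created": 0, "model": "m", '
#             '"object": "chat.completion.chunk", "choices": '
#             '[{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}]}\n\n'
#         )
#         yield "data: [DONE]\n\n"
#
#     response = await accumulate_streaming_response(sse_chunks())
#     assert response.choices[0].message.content == "Hi"
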
def verify_harmony_messages(
    messages: list[Any], expected_messages: list[dict[str, Any]]
):
    assert len(messages) == len(expected_messages)
    for msg, expected in zip(messages, expected_messages):
        if "role" in expected:
            assert msg.author.role == expected["role"]
        if "author_name" in expected:
            assert msg.author.name == expected["author_name"]
        if "channel" in expected:
            assert msg.channel == expected["channel"]
        if "recipient" in expected:
            assert msg.recipient == expected["recipient"]
        if "content" in expected:
            assert msg.content[0].text == expected["content"]
        if "content_type" in expected:
            assert msg.content_type == expected["content_type"]
        if "tool_definitions" in expected:
            # Check that the tool definitions match the expected list of tool names
            actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
            assert actual_tools == expected["tool_definitions"]
def verify_chat_response(
    response: ChatCompletionResponse,
    content: str | None = None,
    reasoning: str | None = None,
    tool_calls: list[tuple[str, str]] | None = None,
):
    assert len(response.choices) == 1
    message = response.choices[0].message

    if content is not None:
        assert message.content == content
    else:
        assert not message.content

    if reasoning is not None:
        assert message.reasoning == reasoning
    else:
        assert not message.reasoning

    if tool_calls:
        assert message.tool_calls is not None
        assert len(message.tool_calls) == len(tool_calls)
        for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
            assert tc.function.name == expected_name
            assert tc.function.arguments == expected_args
    else:
        assert not message.tool_calls
@ -116,7 +116,6 @@ def test_mrope(

    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        is_neox_style=is_neox_style,
        rope_parameters=config.rope_parameters,
@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(

    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        is_neox_style=is_neox_style,
        rope_parameters=config.rope_parameters,
@ -83,8 +83,12 @@ def test_rotary_embedding(
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
    rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
    rope_parameters = {
        "rope_type": "default",
        "rope_theta": rope_theta,
        "partial_rotary_factor": rotary_dim / head_size,
    }
    rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
    rope = rope.to(dtype=dtype, device=torch.get_default_device())

    positions = torch.randint(0, max_position, (batch_size, seq_len))
@ -150,9 +154,9 @@ def test_rope_module_cache():
    if rotary_dim is None:
        rotary_dim = head_size
    rope_parameters["rope_theta"] = rope_theta
    rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        is_neox_style,
        rope_parameters,
@ -177,9 +181,9 @@ def test_rope_module_cache():
    if rotary_dim is None:
        rotary_dim = head_size
    rope_parameters["rope_theta"] = rope_theta
    rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        is_neox_style,
        rope_parameters,
@ -18,7 +18,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
IS_SUPPORTED_BY_GPU = (
    current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
)


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
@ -173,10 +173,7 @@ class _HfExamplesInfo:

_TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": _HfExamplesInfo(
        "arcee-ai/Trinity-Nano",
        is_available_online=False,
    ),
    "AfmoeForCausalLM": _HfExamplesInfo("arcee-ai/Trinity-Nano-Preview"),
    "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
    "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
    "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
@ -18,47 +18,53 @@ def mistral_tokenizer():
    return mistral_tokenizer


SIMPLE_REASONING = {
INVALID_SIMPLE_REASONING = {
    "output": "This is a reasoning section[/THINK]This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
    "reasoning": None,
    "content": "This is a reasoning sectionThis is the rest",
    "is_reasoning_end": False,
}
COMPLETE_REASONING = {
INVALID_COMPLETE_REASONING = {
    "output": "This is a reasoning section[/THINK]",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
    "reasoning": None,
    "content": "This is a reasoning section",
    "is_reasoning_end": False,
}
NO_CONTENT = {
    "output": "This is content",
    "reasoning": "This is content",
    "output": "[THINK]This is reasoning",
    "reasoning": "This is reasoning",
    "content": None,
    "is_reasoning_end": False,
}
NO_REASONING = {
    "output": "This is content",
    "reasoning": None,
    "content": "This is content",
    "is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
    "output": "This is a reasoning section",
    "reasoning": "This is a reasoning section",
    "content": None,
    "reasoning": None,
    "content": "This is a reasoning section",
    "is_reasoning_end": False,
}
MULTIPLE_LINES = {
INVALID_MULTIPLE_LINES = {
    "output": "This\nThat[/THINK]This is the rest\nThat",
    "reasoning": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
    "reasoning": None,
    "content": "This\nThatThis is the rest\nThat",
    "is_reasoning_end": False,
}
SHORTEST_REASONING_NO_STREAMING = {
    "output": "[/THINK]This is the rest",
    "reasoning": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
SHORTEST_REASONING = {
INVALID_SHORTEST_REASONING_NO_STREAMING = {
    "output": "[/THINK]This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
    "is_reasoning_end": False,
}
INVALID_SHORTEST_REASONING = {
    "output": "[/THINK]This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": False,
}
REASONING_WITH_THINK = {
    "output": "[THINK]This is a reasoning section[/THINK]This is the rest",
@ -78,17 +84,17 @@ MULTIPLE_LINES_WITH_THINK = {
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
    "output": "[/THINK]This is the rest",
    "reasoning": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
    "output": "[/THINK]This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
    "is_reasoning_end": False,
}
INVALID_SHORTEST_REASONING_WITH_THINK = {
    "output": "[/THINK]This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": False,
}
THINK_NO_END = {
    "output": "[THINK]This is a reasoning section",
@ -98,8 +104,8 @@ THINK_NO_END = {
}
EMPTY = {
    "output": "",
    "reasoning": "",
    "content": None,
    "reasoning": None,
    "content": "",
    "is_reasoning_end": False,
}
EMPTY_STREAMING = {
@ -109,47 +115,48 @@ EMPTY_STREAMING = {
    "is_reasoning_end": False,
}
NEW_LINE = {
    "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
    "output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
    "reasoning": "This is a reasoning section",
    "content": "\nThis is the rest",
    "content": "Before\n\nThis is the rest",
    "is_reasoning_end": True,
}
# Streaming cannot handle new lines at the beginning of the output
# because we need to support [THINK]...[/THINK] and [/THINK]...
# We cannot know if the text before [THINK] is reasoning content
# or not.
NEW_LINE_STREAMING = {
    "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
    "reasoning": "\nThis is a reasoning section",
    "content": "\nThis is the rest",
    "output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
    "reasoning": "This is a reasoning section",
    "content": "Before\n\nThis is the rest",
    "is_reasoning_end": True,
}

TEST_CASES = [
    pytest.param(
        False,
        SIMPLE_REASONING,
        id="simple_reasoning",
        INVALID_SIMPLE_REASONING,
        id="invalid_simple_reasoning",
    ),
    pytest.param(
        True,
        SIMPLE_REASONING,
        id="simple_reasoning_streaming",
        INVALID_SIMPLE_REASONING,
        id="invalid_simple_reasoning_streaming",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
        INVALID_COMPLETE_REASONING,
        id="invalid_complete_reasoning",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_streaming",
        INVALID_COMPLETE_REASONING,
        id="invalid_complete_reasoning_streaming",
    ),
    pytest.param(
        False,
        NO_CONTENT,
        id="no_content_token",
        id="no_content",
    ),
    pytest.param(
        False,
        NO_REASONING,
        id="no_reasoning",
    ),
    pytest.param(
        True,
@ -158,23 +165,23 @@ TEST_CASES = [
    ),
    pytest.param(
        False,
        MULTIPLE_LINES,
        id="multiple_lines",
        INVALID_MULTIPLE_LINES,
        id="invalid_multiple_lines",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
        INVALID_MULTIPLE_LINES,
        id="invalid_multiple_lines_streaming",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING,
        id="shortest",
        INVALID_SHORTEST_REASONING,
        id="invalid_shortest",
    ),
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING,
        id="shortest_streaming",
        INVALID_SHORTEST_REASONING_NO_STREAMING,
        id="invalid_shortest_streaming",
    ),
    pytest.param(
        False,
@ -208,13 +215,13 @@ TEST_CASES = [
    ),
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
        id="shortest_with_think",
        INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
        id="invalid_shortest_with_think",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING_WITH_THINK,
        id="shortest_with_think_streaming",
        INVALID_SHORTEST_REASONING_WITH_THINK,
        id="invalid_shortest_with_think_streaming",
    ),
    pytest.param(
        False,
@ -316,10 +323,26 @@ def test_mistral_reasoning(

    # Test extract_content
    if param_dict["content"] is not None:
        content = parser.extract_content_ids(output_tokens)
        assert content == mistral_tokenizer.tokenizer.encode(
            param_dict["content"], bos=False, eos=False
        # Handle the case where there are tokens output before [THINK].
        # This should not occur if the model is well trained and prompted.
        if "[THINK]" in param_dict["output"] and not param_dict["output"].startswith(
            "[THINK]"
        ):
            before_content = param_dict["output"].split("[THINK]")[0]
            before_token_ids = mistral_tokenizer.tokenizer.encode(
                before_content, bos=False, eos=False
            )
            left_to_encode = param_dict["content"][len(before_content) :]
        # Normal situation.
        else:
            before_token_ids = []
            left_to_encode = param_dict["content"]

        content_tokens = parser.extract_content_ids(output_tokens)
        expected_token_ids = before_token_ids + mistral_tokenizer.tokenizer.encode(
            left_to_encode, bos=False, eos=False
        )
        assert content_tokens == expected_token_ids
    else:
        content = parser.extract_content_ids(output_tokens)
        assert content == []
@ -3,12 +3,45 @@
# for users who do not have any compilers installed on their system

set -e
set -x

merge_base_commit=$(git merge-base HEAD origin/main)
echo "Current merge base commit with main: $merge_base_commit"
echo "INFO: current merge base commit with main: $merge_base_commit"
git show --oneline -s $merge_base_commit

# test whether the metadata.json url is valid, retry every 3 minutes up to 5 times
# this avoids cumbersome error messages & manual retries in case the precompiled wheel
# for the given commit is still being built in the release pipeline
meta_json_url="https://wheels.vllm.ai/$merge_base_commit/vllm/metadata.json"
echo "INFO: will use metadata.json from $meta_json_url"

for i in {1..5}; do
    echo "Checking metadata.json URL (attempt $i)..."
    if curl --fail "$meta_json_url" > metadata.json; then
        echo "INFO: metadata.json URL is valid."
        # check whether it is valid json by python
        if python3 -m json.tool metadata.json; then
            echo "INFO: metadata.json is valid JSON. Proceeding with the test."
        else
            echo "CRITICAL: metadata.json exists but is not valid JSON, please report it in the #sig-ci channel!"
            exit 1
        fi
        break
    fi
    # failure handling
    if [ $i -eq 5 ]; then
        echo "ERROR: metadata.json URL is still not valid after 5 attempts."
        echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists."
        echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes."
        echo " NOTE: If it still fails, please report it in the #sig-ci channel."
        exit 1
    else
        echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..."
        sleep 180
    fi
done

set -x

cd /vllm-workspace/

# uninstall vllm
@ -29,6 +62,6 @@ python3 -c 'import vllm'

# Check that the changed-file marker was created
if [ ! -f /tmp/changed.file ]; then
    echo "changed.file was not created, python only compilation failed"
    echo "ERROR: changed.file was not created, python only compilation failed"
    exit 1
fi
@ -13,6 +13,7 @@ import torch

from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer
from vllm.utils.import_utils import has_deep_ep

# Detect Blackwell / B200 (compute capability 10.x)
try:
@ -44,6 +45,7 @@ DEEPEP_BACKENDS = [
]


@pytest.mark.skipif(not has_deep_ep(), reason="These tests require deep_ep to run")
@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS)
@pytest.mark.xfail(
    IS_BLACKWELL,
@ -16,6 +16,16 @@ from vllm.platforms import current_platform
MTP_SIMILARITY_RATE = 0.8


def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip test if available GPUs < tp_size on ROCm."""
    if current_platform.is_rocm():
        available_gpus = torch.cuda.device_count()
        if available_gpus < tp_size:
            pytest.skip(
                f"Test requires {tp_size} GPUs, but only {available_gpus} available"
            )


def get_test_prompts(mm_enabled: bool):
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
@ -280,9 +290,20 @@ def test_speculators_model_integration(


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled", "enable_chunked_prefill"],
    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
    [
        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
        ),
        pytest.param(
            (
                "eagle3",
@ -292,6 +313,7 @@ def test_speculators_model_integration(
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
@ -305,6 +327,7 @@ def test_speculators_model_integration(
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a multiple of 32"
            ),
@ -318,6 +341,7 @@ def test_speculators_model_integration(
            ),
            False,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=40),
        ),  # works on 4x H100
        (
@ -329,6 +353,7 @@ def test_speculators_model_integration(
            ),
            False,
            False,
            "auto",
        ),
        pytest.param(
            (
@ -339,6 +364,7 @@ def test_speculators_model_integration(
            ),
            False,
            False,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        pytest.param(
@ -350,6 +376,7 @@ def test_speculators_model_integration(
            ),
            True,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        (
@ -361,10 +388,12 @@ def test_speculators_model_integration(
            ),
            False,
            False,
            "auto",
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle",
@ -381,6 +410,7 @@ def test_eagle_correctness(
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    if attn_backend == "TREE_ATTN":
@ -389,6 +419,17 @@ def test_eagle_correctness(
            "TREE_ATTN is flaky in the test; disabled for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )
    if model_impl == "transformers":
        import transformers
        from packaging.version import Version

        installed = Version(transformers.__version__)
        required = Version("5.0.0.dev")
        if installed < required:
            pytest.skip(
                "Eagle3 with the Transformers modeling backend requires "
                f"transformers>={required}, but got {installed}"
            )

    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)
@ -424,6 +465,8 @@ def test_eagle_correctness(
        m.setenv("VLLM_ROCM_USE_AITER", "1")

    method, model_name, spec_model_name, tp_size = model_setup
    _skip_if_insufficient_gpus_for_tp(tp_size)

    max_model_len = 2048
    max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len

@ -448,6 +491,7 @@ def test_eagle_correctness(
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
    )
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    matches = 0
@ -493,6 +537,7 @@ def test_mtp_correctness(
        m.setenv("VLLM_MLA_DISABLE", "1")

    method, model_name, tp_size = model_setup
    _skip_if_insufficient_gpus_for_tp(tp_size)

    ref_llm = LLM(
        model=model_name,
@ -141,7 +141,25 @@ class CompilerManager:
            # we use ast.literal_eval to parse the data
            # because it is a safe way to parse Python literals.
            # do not use eval(), it is unsafe.
            self.cache = ast.literal_eval(f.read())
            cache = ast.literal_eval(f.read())

            def check_type(value, ty):
                if not isinstance(value, ty):
                    raise TypeError(f"Expected {ty} but got {type(value)} for {value}")

            def parse_key(key: Any) -> tuple[Range, int, str]:
                range_tuple, graph_index, compiler_name = key
                check_type(graph_index, int)
                check_type(compiler_name, str)
                if isinstance(range_tuple, tuple):
                    start, end = range_tuple
                    check_type(start, int)
                    check_type(end, int)
                    range_tuple = Range(start=start, end=end)
                check_type(range_tuple, Range)
                return range_tuple, graph_index, compiler_name

            self.cache = {parse_key(key): value for key, value in cache.items()}

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:

        compiled_ptr = self.check_invariants_and_forward

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                torch._dynamo.config.enable_aot_compile = True
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                msg = "torch._dynamo.config.enable_aot_compile is not "
                msg += "available. AOT compile is disabled; please "
                msg += "upgrade your PyTorch version to use AOT compile."
                logger.warning(msg)

        self._compiled_callable = torch.compile(
            compiled_ptr,
            fullgraph=True,
            dynamic=False,
            backend=backend,
            options=options,
        )
        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
@ -539,6 +539,11 @@ class ModelConfig:

        self.original_max_model_len = self.max_model_len
        self.max_model_len = self.get_and_verify_max_len(self.max_model_len)

        if self.is_encoder_decoder:
            self.mm_processor_cache_gb = 0
            logger.info("Encoder-decoder model detected, disabling mm processor cache.")

        # Init multimodal config if needed
        if self._model_info.supports_multimodal:
            if (
@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
    )


def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
def getattr_iter(
    object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
    """
    A helper function that retrieves an attribute from an object which may
    have multiple possible names. This is useful when fetching attributes from
    arbitrary `transformers.PretrainedConfig` instances.

    In the case where the first name in `names` is the preferred name, and
    any other names are deprecated aliases, setting `warn=True` will log a
    warning when a deprecated name is used.
    """
    for name in names:
    for i, name in enumerate(names):
        if hasattr(object, name):
            if warn and i > 0:
                logger.warning_once(
                    "%s contains a deprecated attribute name '%s'. "
                    "Please use the preferred attribute name '%s' instead.",
                    type(object).__name__,
                    name,
                    names[0],
                )
            return getattr(object, name)
    return default

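# Illustrative usage (hypothetical, not from the diff): `rope_theta` stands in
# as the preferred name and `rotary_emb_base` as a deprecated alias; with
# warn=True a one-time warning is logged when only the alias matches.
#
#     theta = getattr_iter(
#         hf_config, ("rope_theta", "rotary_emb_base"), 10000.0, warn=True
#     )
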
@ -750,27 +750,17 @@ class VllmConfig:
        # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
        self._set_compile_ranges()

        if self.model_config and self.model_config.is_encoder_decoder:
            from vllm.multimodal import MULTIMODAL_REGISTRY

            self.scheduler_config.max_num_encoder_input_tokens = (
                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
        if (
            self.model_config
            and self.model_config.architecture == "WhisperForConditionalGeneration"
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
        ):
            logger.warning(
                "Whisper is known to have issues with "
                "forked workers. If startup is hanging, "
                "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                "to 'spawn'."
            )
            logger.debug(
                "Encoder-decoder model detected: setting "
                "`max_num_encoder_input_tokens` to encoder length (%s)",
                self.scheduler_config.max_num_encoder_input_tokens,
            )
            if (
                self.model_config.architecture == "WhisperForConditionalGeneration"
                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
            ):
                logger.warning(
                    "Whisper is known to have issues with "
                    "forked workers. If startup is hanging, "
                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                    "to 'spawn'."
                )

        if (
            self.kv_events_config is not None
@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
    LMCacheAsyncLookupServer,
)
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher

try:
    from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
except ImportError:
    # Backwards compatibility for lmcache <= 0.3.10-post1
    from lmcache.v1.plugin.plugin_launcher import (
        PluginLauncher as RuntimePluginLauncher,
    )

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
@ -232,7 +232,177 @@ def parse_response_input(
    return msg


def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
    """
    Parse a list of messages from request.messages in the Chat Completion API to
    Harmony messages.
    """
    msgs: list[Message] = []
    tool_id_names: dict[str, str] = {}

    # Collect tool id to name mappings for tool response recipient values
    for chat_msg in chat_msgs:
        for tool_call in chat_msg.get("tool_calls", []):
            tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
                "name"
            )

    for chat_msg in chat_msgs:
        msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))

    msgs = auto_drop_analysis_messages(msgs)
    return msgs


def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
    """
    Harmony models expect the analysis messages (representing raw chain of thought) to
    be dropped after an assistant message to the final channel is produced from the
    reasoning of those messages.

    The openai-harmony library does this if the very last assistant message is to the
    final channel, but it does not handle longer multi-turn conversations where the
    client gave us reasoning content from previous turns, with multiple assistant
    messages to the final channel in the conversation.

    So, we find the index of the last assistant message to the final channel and drop
    all analysis messages that precede it, leaving only the analysis messages that
    are relevant to the current part of the conversation.
    """
    last_assistant_final_index = -1
    for i in range(len(msgs) - 1, -1, -1):
        msg = msgs[i]
        if msg.author.role == "assistant" and msg.channel == "final":
            last_assistant_final_index = i
            break

    cleaned_msgs: list[Message] = []
    for i, msg in enumerate(msgs):
        if i < last_assistant_final_index and msg.channel == "analysis":
            continue
        cleaned_msgs.append(msg)

    return cleaned_msgs

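An illustrative trace (not part of the diff) of the dropping rule, using SimpleNamespace stand-ins for Harmony Message objects, since only .author.role and .channel are inspected:

from types import SimpleNamespace

def _msg(role, channel):
    return SimpleNamespace(author=SimpleNamespace(role=role), channel=channel)

history = [
    _msg("assistant", "analysis"),   # turn-1 reasoning -> dropped
    _msg("assistant", "final"),      # turn-1 answer (last final message)
    _msg("assistant", "analysis"),   # current-turn reasoning -> kept
]
kept = auto_drop_analysis_messages(history)
assert [m.channel for m in kept] == ["final", "analysis"]
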
def flatten_chat_text_content(content: str | list | None) -> str | None:
    """
    Extract the text parts from a chat message content field and flatten them
    into a single string.
    """
    if isinstance(content, list):
        return "".join(
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        )
    return content

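A quick check (not part of the diff) of the flattening behaviour:

assert flatten_chat_text_content("hello") == "hello"
assert flatten_chat_text_content(None) is None
parts = [{"type": "text", "text": "a"}, {"type": "image_url"}, {"type": "text", "text": "b"}]
assert flatten_chat_text_content(parts) == "ab"  # non-text parts are skipped
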
def parse_chat_input_to_harmony_message(
    chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
    """
    Parse a message from request.messages in the Chat Completion API to
    Harmony messages.
    """
    tool_id_names = tool_id_names or {}

    if not isinstance(chat_msg, dict):
        # Handle Pydantic models
        chat_msg = chat_msg.model_dump(exclude_none=True)

    role = chat_msg.get("role")
    msgs: list[Message] = []

    # Assistant message with tool calls
    tool_calls = chat_msg.get("tool_calls", [])

    if role == "assistant" and tool_calls:
        content = flatten_chat_text_content(chat_msg.get("content"))
        if content:
            commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
            commentary_msg = commentary_msg.with_channel("commentary")
            msgs.append(commentary_msg)

        reasoning_content = chat_msg.get("reasoning") or chat_msg.get(
            "reasoning_content"
        )
        if reasoning_content:
            analysis_msg = Message.from_role_and_content(
                Role.ASSISTANT, reasoning_content
            )
            analysis_msg = analysis_msg.with_channel("analysis")
            msgs.append(analysis_msg)

        for call in tool_calls:
            func = call.get("function", {})
            name = func.get("name", "")
            arguments = func.get("arguments", "") or ""
            msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
            msg = msg.with_channel("commentary")
            msg = msg.with_recipient(f"functions.{name}")
            # Officially, this should be `<|constrain|>json` but there is no clear
            # evidence that it improves accuracy over `json` and some anecdotes to
            # the contrary. Further testing of the different content_types is needed.
            msg = msg.with_content_type("json")
            msgs.append(msg)
        return msgs

    # Tool role message (tool output)
    if role == "tool":
        tool_call_id = chat_msg.get("tool_call_id", "")
        name = tool_id_names.get(tool_call_id, "")
        content = chat_msg.get("content", "") or ""
        content = flatten_chat_text_content(content)

        msg = (
            Message.from_author_and_content(
                Author.new(Role.TOOL, f"functions.{name}"), content
            )
            .with_channel("commentary")
            .with_recipient("assistant")
        )
        return [msg]

    # Non-tool reasoning content
    reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")
    if role == "assistant" and reasoning_content:
        analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
        analysis_msg = analysis_msg.with_channel("analysis")
        msgs.append(analysis_msg)

    # Default: user/assistant/system messages with content
    content = chat_msg.get("content") or ""
    if content is None:
        content = ""
    if isinstance(content, str):
        contents = [TextContent(text=content)]
    else:
        # TODO: Support refusal.
        contents = [TextContent(text=c.get("text", "")) for c in content]

    # Only add assistant messages if they have content, as reasoning or tool calling
    # assistant messages were already added above.
    if role == "assistant" and contents and contents[0].text:
        msg = Message.from_role_and_contents(role, contents)
        # Send non-tool assistant messages to the final channel
        msg = msg.with_channel("final")
        msgs.append(msg)
    # For user/system/developer messages, add them directly even if no content.
    elif role != "assistant":
        msg = Message.from_role_and_contents(role, contents)
        msgs.append(msg)

    return msgs

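An end-to-end illustration (not part of the diff): a tool-calling exchange from the Chat Completions API run through the parsers above; the get_weather tool is hypothetical.

chat_msgs = [
    {"role": "user", "content": "Weather in Paris?"},
    {
        "role": "assistant",
        "reasoning_content": "Need the weather tool.",
        "tool_calls": [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "18C, clear"},
]
msgs = parse_chat_inputs_to_harmony_messages(chat_msgs)
# Yields: the user message; an analysis message with the reasoning; a
# commentary message addressed to functions.get_weather carrying the JSON
# arguments; and the tool output back on the commentary channel.
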
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
    """
    Parse a message from request.previous_input_messages in the Responses API to
    Harmony messages.
    """
    if not isinstance(chat_msg, dict):
        # Handle Pydantic models
        chat_msg = chat_msg.model_dump(exclude_none=True)
@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
    if role == "tool":
        name = chat_msg.get("name", "")
        content = chat_msg.get("content", "") or ""
        if isinstance(content, list):
            # Handle array format for tool message content
            # by concatenating all text parts.
            content = "".join(
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        content = flatten_chat_text_content(content)

        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{name}"), content
@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def parse_chat_output(
    token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
    """
    Parse the output of a Harmony chat completion into reasoning and final content.
    Note that when the `openai` tool parser is used, serving_chat only uses this
    for the reasoning content and gets the final content from the tool call parser.

    When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
    in use, this needs to return the final content without any tool calls parsed.

    Empty reasoning or final content is returned as None instead of an empty string.
    """
    parser = parse_output_into_messages(token_ids)
    output_msgs = parser.messages
    is_tool_call = False  # TODO: update this when tool call is supported
    if len(output_msgs) == 0:
        # The generation has stopped during reasoning.
        reasoning = parser.current_content
        final_content = None
    elif len(output_msgs) == 1:
        # The generation has stopped during final message.
        reasoning = output_msgs[0].content[0].text
        final_content = parser.current_content
    else:
        reasoning_msg = output_msgs[:-1]
        final_msg = output_msgs[-1]
        reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
        final_content = final_msg.content[0].text

    # Get completed messages from the parser
    reasoning_texts = [
        msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
    ]
    final_texts = [
        msg.content[0].text for msg in output_msgs if msg.channel != "analysis"
    ]

    # Extract partial messages from the parser
    if parser.current_channel == "analysis" and parser.current_content:
        reasoning_texts.append(parser.current_content)
    elif parser.current_channel != "analysis" and parser.current_content:
        final_texts.append(parser.current_content)

    # Flatten multiple messages into a single string
    reasoning: str | None = "\n".join(reasoning_texts)
    final_content: str | None = "\n".join(final_texts)

    # Return None instead of empty string since existing callers check for None
    reasoning = reasoning or None
    final_content = final_content or None

    return reasoning, final_content, is_tool_call

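The new logic reduces to a by-channel partition; a standalone sketch (not part of the diff) using stand-in message objects:

from types import SimpleNamespace

msgs = [
    SimpleNamespace(channel="analysis", content=[SimpleNamespace(text="think")]),
    SimpleNamespace(channel="final", content=[SimpleNamespace(text="answer")]),
]
reasoning = "\n".join(m.content[0].text for m in msgs if m.channel == "analysis") or None
final = "\n".join(m.content[0].text for m in msgs if m.channel != "analysis") or None
assert (reasoning, final) == ("think", "answer")
# Either side collapses to None when its channel produced no text.
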
@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
    get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    parse_chat_output,
    parse_input_to_harmony_message,
    render_for_completion,
)
from vllm.entrypoints.openai.protocol import (
@ -822,6 +822,9 @@ class OpenAIServingChat(OpenAIServing):

                        if delta_message is not None:
                            harmony_tools_streamed[i] = True
                    elif cur_channel == "commentary":
                        # Tool call preambles meant to be shown to the user
                        delta_message = DeltaMessage(content=delta_text)
                    else:
                        delta_message = None
                # handle streaming deltas for tools with named tool_choice
@ -1770,6 +1773,11 @@ class OpenAIServingChat(OpenAIServing):
    ):
        messages: list[OpenAIMessage] = []

        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        maybe_serialize_tool_calls(request)

        # Add system message.
        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
@ -1788,8 +1796,7 @@ class OpenAIServingChat(OpenAIServing):
            messages.append(dev_msg)

        # Add user message.
        for chat_msg in request.messages:
            messages.extend(parse_input_to_harmony_message(chat_msg))
        messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))

        # Render prompt token ids.
        prompt_token_ids = render_for_completion(messages)

@ -43,6 +43,7 @@ class OpenAIToolParser(ToolParser):
        parser = parse_output_into_messages(token_ids)
        tool_calls = []
        final_content = None
        commentary_content = None

        if len(parser.messages) > 0:
            for msg in parser.messages:
@ -75,11 +76,15 @@ class OpenAIToolParser(ToolParser):
                    )
                elif msg.channel == "final":
                    final_content = msg_text
                elif msg.channel == "commentary" and not msg.recipient:
                    commentary_content = msg_text

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=final_content,
            # prefer final content over commentary content if both are present
            # commentary content is tool call preambles meant to be shown to the user
            content=final_content or commentary_content,
        )

    def extract_tool_calls_streaming(

@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
    )


def gptq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    weight_bits: int,
    group_size: int,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
):
    """
    Construct a quant config for gptq marlin quantization.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)

    # Activations are NOT quantized for GPTQ (fp16/bf16)
    a_shape = w_shape  # Same as weight shape for alignment

    # Determine weight dtype
    if weight_bits == 4:
        weight_dtype = "int4"
    elif weight_bits == 8:
        weight_dtype = torch.int8
    else:
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")

    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
    )

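A sketch (not part of the diff) of driving this factory directly; the scale tensor shapes are illustrative placeholders, not the real checkpoint layout:

import torch

cfg = gptq_marlin_moe_quant_config(
    w1_scale=torch.ones(8, 112, 14336),  # per-expert group scales (made up)
    w2_scale=torch.ones(8, 56, 4096),
    weight_bits=4,    # mapped to the "int4" weight dtype above
    group_size=128,   # -> GroupShape(row=1, col=128); -1 would mean no group shape
)
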
def mxfp4_w4a16_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
@ -700,6 +736,42 @@ def int4_w4afp8_moe_quant_config(
    )


def awq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    weight_bits: int,
    group_size: int,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for awq marlin quantization.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)

    # Activations are NOT quantized for AWQ (fp16/bf16)
    a_shape = w_shape  # Same as weight shape for alignment

    # Determine weight dtype
    if weight_bits == 4:
        weight_dtype = "int4"
    elif weight_bits == 8:
        weight_dtype = torch.int8
    else:
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")

    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
    )


def biased_moe_quant_config(
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,

@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1(
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    # any rows in the per-expert aligned region that do not correspond to
    # real tokens are left untouched here and should remain initialized to
    # -1 so DeepGEMM can skip them
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        offs = start_m + off_expert
        mask = offs < cur_expert_token_num
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            m_indices_start_ptr + offs,
            cur_expert,
            mask=mask,
        )

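The masked store above has simple NumPy-style semantics; an illustration (not part of the diff):

import numpy as np

m_indices = np.full(8, -1)                  # aligned region, pre-filled with -1
cur_expert, cur_expert_token_num = 3, 5
offs = np.arange(8)                         # BLOCK_E = 8 for the sketch
mask = offs < cur_expert_token_num
m_indices[offs[mask]] = cur_expert          # the tl.store with mask=mask
# -> [3, 3, 3, 3, 3, -1, -1, -1]; the untouched -1 rows are padding
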
@ -366,12 +372,17 @@ def deepgemm_moe_permute(
        (M_sum, H // block_k), device=device, dtype=torch.float32
    )

    maybe_has_empty_blocks = (expert_tokens_meta is None) or (
        expert_tokens_meta.expert_num_tokens_cpu is None
    # DeepGEMM uses negative values in m_indices (here expert_ids) to mark
    # completely invalid / padded blocks that should be skipped. We always
    # initialize expert_ids to -1 so any row that is not explicitly written
    # by the scatter kernel will be treated as invalid and skipped by
    # DeepGEMM's scheduler.
    expert_ids = torch.full(
        (M_sum,),
        fill_value=-1,
        device=device,
        dtype=torch.int32,
    )
    expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty

    expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
    inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)

    expert_num_tokens = None

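Why the switch from zeros to -1 matters, in a standalone sketch (not part of the diff):

import torch

M_sum = 8
real = torch.tensor([0, 0, 1, 1, 2], dtype=torch.int32)  # rows with real tokens
old = torch.zeros(M_sum, dtype=torch.int32)   # padded rows alias expert 0
new = torch.full((M_sum,), -1, dtype=torch.int32)
old[:5] = real
new[:5] = real
# old -> [0, 0, 1, 1, 2, 0, 0, 0]    padded rows look like expert-0 work
# new -> [0, 0, 1, 1, 2, -1, -1, -1] negative rows are skipped by DeepGEMM
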
@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
            }
        )

        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full", intermediate_size_per_partition
        )
        self.is_k_full = intermediate_size_per_partition == intermediate_size_full

        w13_qweight = Parameter(
            torch.empty(
                num_experts,
@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # The modular kernel expects w13_weight and w2_weight,
        # but AWQ uses w13_qweight and w2_qweight
        # Alias for modular kernel
        layer.w13_weight = layer.w13_qweight
        # Alias for modular kernel
        layer.w2_weight = layer.w2_qweight

        # Why does this take the intermediate size for size_k?
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None
        from vllm.model_executor.layers.fused_moe.config import (
            awq_marlin_moe_quant_config,
        )

        return awq_marlin_moe_quant_config(
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            weight_bits=self.quant_config.weight_bits,
            group_size=self.quant_config.group_size,
            w1_zp=getattr(layer, "w13_qzeros", None)
            if self.quant_config.zero_point
            else None,
            w2_zp=getattr(layer, "w2_qzeros", None)
            if self.quant_config.zero_point
            else None,
            w1_bias=getattr(layer, "w13_bias", None),
            w2_bias=getattr(layer, "w2_bias", None),
        )

    def select_gemm_impl(
        self,
        prepare_finalize,
        layer: torch.nn.Module,
    ):
        """
        Select the GEMM implementation for AWQ-Marlin MoE.
        Returns MarlinExperts configured for AWQ quantization.
        This is ONLY used when LoRA is enabled.
        Without LoRA, AWQ uses its own apply() method.
        """
        # Only use modular kernels when LoRA is enabled
        # Without LoRA, AWQ's own apply() method works fine and is more efficient
        if not self.moe.is_lora_enabled:
            raise NotImplementedError(
                "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
                "Modular kernels are only used for LoRA support."
            )

        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
            MarlinExperts,
        )

        # Ensure quant config is initialized
        assert self.moe_quant_config is not None, (
            "moe_quant_config must be initialized before select_gemm_impl"
        )

        w13_g_idx = getattr(layer, "w13_g_idx", None)
        w2_g_idx = getattr(layer, "w2_g_idx", None)
        w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None)
        w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None)

        # Check if using batched expert format (for Expert Parallelism)
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # For batched format, use BatchedMarlinExperts
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )
        else:
            # Standard Marlin experts for AWQ
            return MarlinExperts(
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )

    def apply(
        self,

@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # The modular kernel expects w13_weight and w2_weight,
        # but GPTQ uses w13_qweight and w2_qweight
        # Alias for modular kernel
        layer.w13_weight = layer.w13_qweight
        # Alias for modular kernel
        layer.w2_weight = layer.w2_qweight

        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None
        from vllm.model_executor.layers.fused_moe.config import (
            gptq_marlin_moe_quant_config,
        )

        return gptq_marlin_moe_quant_config(
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            weight_bits=self.quant_config.weight_bits,
            group_size=self.quant_config.group_size,
            w1_zp=getattr(layer, "w13_qzeros", None)
            if not self.quant_config.is_sym
            else None,
            w2_zp=getattr(layer, "w2_qzeros", None)
            if not self.quant_config.is_sym
            else None,
            w1_bias=getattr(layer, "w13_bias", None),
            w2_bias=getattr(layer, "w2_bias", None),
        )

    def select_gemm_impl(
        self,
        prepare_finalize,
        layer: torch.nn.Module,
    ):
        """
        Select the GEMM implementation for GPTQ-Marlin MoE.

        Returns MarlinExperts configured for GPTQ quantization.
        This is ONLY used when LoRA is enabled.
        Without LoRA, GPTQ uses its own apply() method.
        """
        # Only use modular kernels when LoRA is enabled
        # Without LoRA, GPTQ's own apply() method works fine and is more efficient
        if not self.moe.is_lora_enabled:
            raise NotImplementedError(
                "GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
                "Modular kernels are only used for LoRA support."
            )

        # The modular marlin kernels do not support 8-bit weights.
        if self.quant_config.weight_bits == 8:
            raise NotImplementedError(
                "GPTQ-Marlin kernel does not support 8-bit weights."
            )

        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
            MarlinExperts,
        )

        # Ensure quant config is initialized
        assert self.moe_quant_config is not None, (
            "moe_quant_config must be initialized before select_gemm_impl"
        )

        w13_g_idx = (
            getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
        )
        w2_g_idx = (
            getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
        )
        w13_g_idx_sort_indices = (
            getattr(layer, "w13_g_idx_sort_indices", None)
            if self.quant_config.desc_act
            else None
        )
        w2_g_idx_sort_indices = (
            getattr(layer, "w2_g_idx_sort_indices", None)
            if self.quant_config.desc_act
            else None
        )

        # Check if using batched expert format (for Expert Parallelism)
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # For batched format, use BatchedMarlinExperts
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )
        else:
            # Standard Marlin experts for GPTQ
            return MarlinExperts(
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )

    def apply(
        self,

@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    flashinfer_trtllm_fp4_moe,
    flashinfer_trtllm_fp4_routed_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
@ -1325,7 +1326,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
@ -1482,6 +1483,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
            a2_gscale=layer.w2_input_scale_quant,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
@ -1500,11 +1505,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
            and not layer.enable_eplb
        ):
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
                )
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
@ -1522,6 +1524,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                router_logits=router_logits,
            )

        # EPLB path
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
            )

        if self.use_marlin:
            return fused_marlin_moe(
                x,

@ -331,3 +331,82 @@ def flashinfer_trtllm_fp4_moe(
    )[0]

    return out


def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    global_num_experts: int,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top k expert indices and scores rather than computing
    top k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        topk_ids: Ids of selected experts
        topk_weights: Scores of the selected experts
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts across all ranks

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    # Pack top k ids and expert weights into a single int32 tensor, as
    # required by TRT-LLM
    packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
        torch.bfloat16
    ).view(torch.int16)

    # Quantize input to FP4
    a1_gscale = layer.w13_input_scale_quant
    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
        x,
        a1_gscale,
        is_sf_swizzled_layout=False,
    )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=packed_tensor,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=1,
        do_finalize=True,
    )[0]

    return out

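The int32 packing puts the expert id in the high 16 bits and the bfloat16 weight's bit pattern in the low 16 bits; a round-trip check (not part of the diff, positive weights for simplicity):

import torch

topk_ids = torch.tensor([[3, 7]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]])
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)
assert torch.equal(packed >> 16, topk_ids)                        # high half: ids
unpacked = (packed & 0xFFFF).to(torch.int16).view(torch.bfloat16)
assert torch.allclose(unpacked.float(), topk_weights, atol=0.01)  # low half: weights
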
@ -290,7 +290,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    if flashinfer_moe_backend in backend_map:
        if (
            flashinfer_moe_backend == "latency"
            and not current_platform.is_device_capability(100)
            and not current_platform.has_device_capability(100)
        ):
            logger.info_once(
                "Flashinfer TRTLLM MOE backend is only supported on "

@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}

def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    is_neox_style: bool = True,
    rope_parameters: dict[str, Any] | None = None,
@ -54,12 +53,15 @@ def get_rope(
    else:
        dual_chunk_attention_args = None

    partial_rotary_factor = 1.0
    if rope_parameters is not None:
        partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
    rope_parameters = rope_parameters or {}
    base = rope_parameters.get("rope_theta", 10000)
    scaling_type = rope_parameters.get("rope_type", "default")
    partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)

    if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
        raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
    rotary_dim = int(head_size * partial_rotary_factor)

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (
        head_size,
        rotary_dim,
@ -72,7 +74,6 @@ def get_rope(
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    base = rope_parameters["rope_theta"] if rope_parameters else 10000
    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
@ -88,109 +89,76 @@ def get_rope(
            dtype,
            **extra_kwargs,
        )
    elif not rope_parameters:
        rotary_emb = RotaryEmbedding(
    elif scaling_type == "default":
        if "mrope_section" in rope_parameters:
            rotary_emb = MRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                mrope_section=rope_parameters["mrope_section"],
                mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
            )
        else:
            rotary_emb = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
            )
    elif scaling_type == "llama3":
        scaling_factor = rope_parameters["factor"]
        low_freq_factor = rope_parameters["low_freq_factor"]
        high_freq_factor = rope_parameters["high_freq_factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        rotary_emb = Llama3RotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
            scaling_factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position,
        )
    elif scaling_type == "mllama4":
        rotary_emb = Llama4VisionRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        )
    else:
        scaling_type = rope_parameters["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_parameters["factor"]
            low_freq_factor = rope_parameters["low_freq_factor"]
            high_freq_factor = rope_parameters["high_freq_factor"]
            original_max_position = rope_parameters["original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                scaling_factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position,
            )
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style, dtype
            )
        elif scaling_type == "default":
            if "mrope_section" in rope_parameters:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_parameters["mrope_section"],
                    mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_parameters["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
        elif scaling_type == "ntk":
            scaling_factor = rope_parameters["factor"]
            mixed_b = rope_parameters.get("mixed_b")
            rotary_emb = NTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
                mixed_b,
            )
        elif scaling_type == "dynamic":
            if "alpha" in rope_parameters:
                scaling_alpha = rope_parameters["alpha"]
                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    scaling_alpha,
                    dtype,
                )
            elif "factor" in rope_parameters:
                scaling_factor = rope_parameters["factor"]
                rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    scaling_factor,
                    dtype,
                )
            else:
                raise ValueError(
                    "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
                )
        elif scaling_type == "xdrope":
    elif scaling_type == "linear":
        scaling_factor = rope_parameters["factor"]
        rotary_emb = LinearScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            scaling_factor,
            dtype,
        )
    elif scaling_type == "ntk":
        scaling_factor = rope_parameters["factor"]
        mixed_b = rope_parameters.get("mixed_b")
        rotary_emb = NTKScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            scaling_factor,
            dtype,
            mixed_b,
        )
    elif scaling_type == "dynamic":
        if "alpha" in rope_parameters:
            scaling_alpha = rope_parameters["alpha"]
            rotary_emb = XDRotaryEmbedding(
            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
@ -198,67 +166,66 @@ def get_rope(
                is_neox_style,
                scaling_alpha,
                dtype,
                xdrope_section=rope_parameters["xdrope_section"],
            )
        elif scaling_type == "yarn":
        elif "factor" in rope_parameters:
            scaling_factor = rope_parameters["factor"]
            original_max_position = rope_parameters["original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_parameters.items()
                if k
                in (
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "apply_yarn_scaling",
                    "truncate",
                )
            }
            if "mrope_section" in rope_parameters:
                extra_kwargs.pop("apply_yarn_scaling", None)
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    original_max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_parameters["mrope_section"],
                    mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
                    scaling_factor=scaling_factor,
                    **extra_kwargs,
                )
            else:
                rotary_emb = YaRNScalingRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    original_max_position,
                    base,
                    is_neox_style,
                    scaling_factor,
                    dtype,
                    **extra_kwargs,
                )
        elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
            scaling_factor = rope_parameters["factor"]
            original_max_position = rope_parameters["original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_parameters.items()
                if k
                in (
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "mscale",
                    "mscale_all_dim",
                )
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
        else:
            raise ValueError(
                "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
            )
    elif scaling_type == "xdrope":
        scaling_alpha = rope_parameters["alpha"]
        rotary_emb = XDRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            scaling_alpha,
            dtype,
            xdrope_section=rope_parameters["xdrope_section"],
        )
    elif scaling_type == "yarn":
        scaling_factor = rope_parameters["factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        extra_kwargs = {
            k: v
            for k, v in rope_parameters.items()
            if k
            in (
                "extrapolation_factor",
                "attn_factor",
                "beta_fast",
                "beta_slow",
                "apply_yarn_scaling",
                "truncate",
            )
        }
        if "mrope_section" in rope_parameters:
            extra_kwargs.pop("apply_yarn_scaling", None)
            rotary_emb = MRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
                base,
                is_neox_style,
                dtype,
                mrope_section=rope_parameters["mrope_section"],
                mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
                scaling_factor=scaling_factor,
                **extra_kwargs,
            )
        else:
            rotary_emb = YaRNScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
@ -268,28 +235,55 @@ def get_rope(
                dtype,
                **extra_kwargs,
            )
        elif scaling_type == "longrope":
            short_factor = rope_parameters["short_factor"]
            long_factor = rope_parameters["long_factor"]
            original_max_position = rope_parameters["original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_parameters.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                original_max_position,
                base,
                is_neox_style,
                dtype,
                short_factor,
                long_factor,
                **extra_kwargs,
    elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
        scaling_factor = rope_parameters["factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        # assert max_position == original_max_position * scaling_factor
        extra_kwargs = {
            k: v
            for k, v in rope_parameters.items()
            if k
            in (
                "extrapolation_factor",
                "attn_factor",
                "beta_fast",
                "beta_slow",
                "mscale",
                "mscale_all_dim",
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
        }
        rotary_emb = DeepseekScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            original_max_position,
            base,
            is_neox_style,
            scaling_factor,
            dtype,
            **extra_kwargs,
        )
    elif scaling_type == "longrope":
        short_factor = rope_parameters["short_factor"]
        long_factor = rope_parameters["long_factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        extra_kwargs = {
            k: v
            for k, v in rope_parameters.items()
            if k in ("short_mscale", "long_mscale")
        }
        rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            original_max_position,
            base,
            is_neox_style,
            dtype,
            short_factor,
            long_factor,
            **extra_kwargs,
        )
    else:
        raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb

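To make the centralized partial_rotary_factor handling concrete, a small sketch (not part of the diff) of the arithmetic get_rope now performs instead of each caller passing rotary_dim:

head_size = 128
rope_parameters = {"rope_theta": 10000, "partial_rotary_factor": 0.5}
rotary_dim = int(head_size * rope_parameters.get("partial_rotary_factor", 1.0))
assert rotary_dim == 64  # only the first half of each head is rotated
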
@ -241,9 +241,8 @@ class AfmoeAttention(nn.Module):
|
||||
if self.is_local_attention:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config["rope_parameters"],
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -226,7 +226,6 @@ class ApertusAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
|
||||
@ -314,7 +314,6 @@ class ArcticAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module):
|
||||
else:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
)
|
||||
|
||||
@ -127,11 +127,11 @@ class BailingAttention(nn.Module):
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
|
||||
rotary_dim = getattr(config, "rotary_dim", self.head_dim)
|
||||
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=config.max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if hasattr(config, "attn_rotary_emb"):
|
||||
rotary_dim = config.attn_rotary_emb # for backward compatibility
|
||||
else:
|
||||
rotary_dim = self.head_dim # default
|
||||
rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
|
||||
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module):
|
||||
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
)
|
||||
|
||||
@ -99,13 +99,16 @@ class GLMAttention(nn.Module):
|
||||
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||
rope_ratio = getattr(config, "rope_ratio", 1.0)
|
||||
max_positions = getattr(config, "seq_length", 8192)
|
||||
rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio}
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": 10000 * rope_ratio,
|
||||
"partial_rotary_factor": 0.5,
|
||||
}
|
||||
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
|
||||
# which is equivalent to is_neox_style=True
|
||||
is_neox_style = not config.original_rope
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim // 2,
|
||||
max_position=max_positions,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
|
||||
@ -175,7 +175,6 @@ class CohereAttention(nn.Module):
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=False,
|
||||
|
||||
@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
config.hidden_act = "geglu"
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
|
||||
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"rope_parameters": config.rope_parameters,
|
||||
}
|
||||
@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
if not model_config.enforce_eager:
|
||||
max_position = round_up(max_position, 8)
|
||||
|
||||
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
|
||||
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
|
||||
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": max_position,
|
||||
"rope_parameters": config.rope_parameters,
|
||||
}
|
||||
@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
config.num_hidden_layers = config.n_layer
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
|
||||
max_trained_positions = getattr(config, "max_trained_positions", 2048)
|
||||
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": rotary_emb_dim,
|
||||
"max_position": max_trained_positions,
|
||||
"rope_parameters": config.rope_parameters,
|
||||
}
|
||||
@ -214,7 +215,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
tokens = getattr(config, "classifier_from_token", None)
|
||||
assert tokens is not None and len(tokens) == 2, (
|
||||
"Try loading the original Qwen3 Reranker?, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
|
||||
)
|
||||
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
|
||||
|
||||
@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
config.hidden_act = "geglu"
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
|
||||
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"rope_parameters": config.rope_parameters,
|
||||
}
|
||||
|
||||
@ -222,7 +222,6 @@ class DbrxAttention(nn.Module):
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -156,7 +156,6 @@ class DeepseekAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
)
|
||||
@ -499,7 +498,6 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=False,
|
||||
@ -1018,7 +1016,6 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=False,
|
||||
@ -1038,7 +1035,6 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
if self.is_v32:
|
||||
self.indexer_rope_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -250,7 +250,6 @@ class Dots1Attention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
)
|
||||
|
||||
@ -288,7 +288,6 @@ class Ernie4_5_MoeAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=False,
|
||||
|
||||
@ -167,7 +167,6 @@ class ExaoneAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
|
||||
@ -176,7 +176,6 @@ class Exaone4Attention(nn.Module):
|
||||
set_default_rope_theta(config, default_theta=1000000)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
|
||||
@ -167,7 +167,6 @@ class FalconAttention(nn.Module):
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
)
|
||||
|
||||
@ -242,14 +242,11 @@ class FalconH1AttentionDecoderLayer(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if hasattr(config, "attn_rotary_emb"):
|
||||
rotary_dim = config.attn_rotary_emb # for backward compatibility
|
||||
else:
|
||||
rotary_dim = self.head_dim # default
|
||||
rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
|
||||
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -174,7 +174,6 @@ class GemmaAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -152,7 +152,6 @@ class Gemma2Attention(nn.Module):
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -176,7 +176,6 @@ class Gemma3Attention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -384,7 +384,6 @@ class Gemma3nAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=True,
|
||||
|
||||
@ -81,7 +81,6 @@ class Glm4Attention(nn.Module):
|
||||
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.rotary_dim = self.head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
@ -103,7 +102,6 @@ class Glm4Attention(nn.Module):
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=False,
|
||||
|
||||
@ -678,9 +678,9 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim // 2,
|
||||
max_position=8192,
|
||||
is_neox_style=True,
|
||||
rope_parameters={"partial_rotary_factor": 0.5},
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
|
||||
@ -285,7 +285,6 @@ class Glm4MoeAttention(nn.Module):
        config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

@ -95,12 +95,13 @@ class GPTJAttention(nn.Module):
        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_parameters = getattr(config, "rope_parameters", {})
        rope_parameters["partial_rotary_factor"] = config.rotary_dim / self.head_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            rope_parameters=rope_parameters,
            is_neox_style=False,
        )
        self.attn = Attention(
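The GPT-J hunk shows the general migration shape: legacy per-argument rope settings move into one rope_parameters dict, with partial rotation expressed as a factor rather than handled by a separate code path. A standalone sketch of the dict construction (values hypothetical):

    from types import SimpleNamespace

    # Hypothetical GPT-J-style config: rotary_dim rotates only part of each head.
    config = SimpleNamespace(rotary_dim=64, rope_parameters={"rope_theta": 10000.0})
    head_size = 256

    assert config.rotary_dim % 2 == 0
    rope_parameters = getattr(config, "rope_parameters", {})
    rope_parameters["partial_rotary_factor"] = config.rotary_dim / head_size  # 0.25
    # This populated dict is what get_rope(...) now receives, replacing the old
    # rope_parameters=getattr(config, "rope_parameters", None) argument.
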
@ -92,7 +92,6 @@ class GPTNeoXAttention(nn.Module):
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=self.head_size,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

@ -67,7 +67,6 @@ class OAIAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
            rope_parameters={

@ -160,7 +160,6 @@ class GraniteAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

@ -190,7 +190,6 @@ class GraniteMoeAttention(nn.Module):
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            is_neox_style=True,

@ -271,7 +271,6 @@ class GraniteMoeHybridAttention(nn.Module):
        if config.position_embedding_type == "rope":
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=config.max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,

@ -181,7 +181,6 @@ class Grok1Attention(nn.Module):
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            is_neox_style=True,

@ -199,7 +199,6 @@ class HunYuanAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
@ -305,7 +304,6 @@ class HunYuanCrossAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,

@ -140,7 +140,6 @@ class InternLM2Attention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )

@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module):
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,

@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module):
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,

@ -259,7 +259,6 @@ class LlamaAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,

@ -243,7 +243,6 @@ class Llama4Attention(nn.Module):
        self.rotary_emb = (
            get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=is_neox_style,

@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )

@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module):

        self.rotary_emb = get_rope(
            self.qk_rope_head_dim,
            rotary_dim=self.qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module):
            prefix=f"{prefix}.o_proj",
        )

        if (
            rope_parameters is not None
            and "partial_rotary_factor" not in rope_parameters
        ):
            rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
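MiniMaxM2 guards the update so that a factor already present in rope_parameters is never overwritten. The same guard in isolation (numbers are illustrative):

    rope_parameters = {"rope_theta": 1_000_000.0}  # hypothetical; no factor set yet
    rotary_dim, head_dim = 64, 128

    if rope_parameters is not None and "partial_rotary_factor" not in rope_parameters:
        rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
    assert rope_parameters["partial_rotary_factor"] == 0.5
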
@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rotary_dim: int,
        max_position: int = 4096 * 32,
        rope_parameters: dict | None = None,
        sliding_window: int | None = None,
@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
        )
        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=rotary_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            is_neox_style=True,
@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = getattr(config, "rotary_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
            max_position_embeddings = min(
                config.max_position_embeddings, config.max_model_len
@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            head_dim=head_dim,
            rotary_dim=config.rotary_dim
            if hasattr(config, "rotary_dim")
            else head_dim,
            num_kv_heads=config.num_key_value_heads,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,

@ -206,7 +206,6 @@ class MixtralAttention(nn.Module):
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,

@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module):
        rope_parameters = {
            "rope_type": "mllama4",
            "rope_theta": config.rope_parameters["rope_theta"],
            "partial_rotary_factor": 0.5,
        }

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=config.hidden_size // config.num_attention_heads // 2,
            # number of image patches
            max_position=(config.image_size // config.patch_size) ** 2,
            rope_parameters=rope_parameters,

@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module):

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=rope_parameters,
            dtype=torch.float16,

@ -433,7 +433,6 @@ class MolmoAttention(nn.Module):
        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

@ -199,7 +199,6 @@ class NemotronAttention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=is_neox_style,

@ -102,7 +102,6 @@ class OlmoAttention(nn.Module):
        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
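Taken together, the hunks converge on a single call shape: rope_theta, rope_scaling, and partial_rotary_factor all travel inside one rope_parameters dict instead of separate keyword arguments. A sketch of that shape, assuming a vLLM build that includes this commit (the variable values below are placeholders):

    from vllm.model_executor.layers.rotary_embedding import get_rope

    head_dim = 128
    max_position_embeddings = 8192
    rope_parameters = {"rope_theta": 10000.0}

    # Same argument layout as the refactored model files above.
    rotary_emb = get_rope(
        head_dim,
        rotary_dim=head_dim,
        max_position=max_position_embeddings,
        rope_parameters=rope_parameters,
    )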