mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 01:47:02 +08:00
Merge branch 'main' into rename_file_info_to_pkg/file
This commit is contained in:
commit
b4b79c5eba
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env bash
# Scheduled integration test: DeepSeek-V2-Lite with expert parallelism and
# async EPLB. For each all2all backend, serve the model, run the GSM8K eval
# against it, and assert that accuracy meets the threshold.
set -euxo pipefail

# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8030}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"

# Block (up to 600s) until the server's /health endpoint responds.
wait_for_server() {
  local port=$1
  timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
  sleep 1
done'
}

MODEL="deepseek-ai/DeepSeek-V2-lite"

# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
  # ROCm platform
  BACKENDS=("allgather_reducescatter")
  # Disable MOE padding for ROCm since it is causing eplb to fail
  export VLLM_ROCM_MOE_PADDING=0
else
  # Non-ROCm platform (CUDA/other)
  BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi

# Stop the server: TERM first, wait up to 10s, then KILL as a last resort.
cleanup() {
  if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
    kill "${SERVER_PID}" 2>/dev/null || true
    for _ in {1..20}; do
      kill -0 "${SERVER_PID}" 2>/dev/null || break
      sleep 0.5
    done
    kill -9 "${SERVER_PID}" 2>/dev/null || true
  fi
}
trap cleanup EXIT

for BACK in "${BACKENDS[@]}"; do
  VLLM_DEEP_GEMM_WARMUP=skip \
  VLLM_ALL2ALL_BACKEND="$BACK" \
  vllm serve "$MODEL" \
    --enforce-eager \
    --tensor-parallel-size 2 \
    --data-parallel-size 2 \
    --enable-expert-parallel \
    --enable-eplb \
    --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
    --trust-remote-code \
    --max-model-len 2048 \
    --port "$PORT" &
  SERVER_PID=$!
  wait_for_server "$PORT"

  # Sanitize the model name into a filesystem-safe tag.
  # NOTE: the previous `tr '/: \\n' '_____'` also translated every literal
  # letter 'n' (tr parses '\\' as a backslash, leaving a bare 'n' in SET1);
  # parameter expansion replaces only '/', ':' and spaces.
  TAG=${MODEL//[\/: ]/_}
  OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json"
  python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port "$PORT" --num-questions "${NUM_Q}" --save-results "${OUT}"
  python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY

  cleanup
  SERVER_PID=
  sleep 1
  PORT=$((PORT+1))
done
|
||||
@ -50,6 +50,7 @@ for BACK in "${BACKENDS[@]}"; do
|
||||
--data-parallel-size 2 \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--eplb-config '{"window_size":200,"step_interval":600}' \
|
||||
--trust-remote-code \
|
||||
--max-model-len 2048 \
|
||||
--port $PORT &
|
||||
|
||||
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env bash
# Scheduled integration test: Qwen3-Next-80B-A3B-Instruct with MTP speculative
# decoding and async EPLB. For each all2all backend, serve the model, run the
# GSM8K eval against it, and assert that accuracy meets the threshold.
set -euxo pipefail

# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8040}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"

# Block (up to 600s) until the server's /health endpoint responds.
wait_for_server() {
  local port=$1
  timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
  sleep 1
done'
}

MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct"

# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
  # ROCm platform
  BACKENDS=("allgather_reducescatter")
  # Disable MOE padding for ROCm since it is causing eplb to fail
  export VLLM_ROCM_MOE_PADDING=0
else
  # Non-ROCm platform (CUDA/other)
  BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi

# Stop the server: TERM first, wait up to 10s, then KILL as a last resort.
cleanup() {
  if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
    kill "${SERVER_PID}" 2>/dev/null || true
    for _ in {1..20}; do
      kill -0 "${SERVER_PID}" 2>/dev/null || break
      sleep 0.5
    done
    kill -9 "${SERVER_PID}" 2>/dev/null || true
  fi
}
trap cleanup EXIT

for BACK in "${BACKENDS[@]}"; do
  VLLM_DEEP_GEMM_WARMUP=skip \
  VLLM_ALL2ALL_BACKEND="$BACK" \
  vllm serve "$MODEL" \
    --enforce-eager \
    --tensor-parallel-size 4 \
    --enable-expert-parallel \
    --enable-eplb \
    --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
    --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
    --trust-remote-code \
    --max-model-len 2048 \
    --gpu-memory-utilization 0.9 \
    --port "$PORT" &
  SERVER_PID=$!
  wait_for_server "$PORT"

  # Sanitize the model name into a filesystem-safe tag.
  # NOTE: the previous `tr '/: \\n' '_____'` also translated every literal
  # letter 'n' (tr parses '\\' as a backslash, leaving a bare 'n' in SET1),
  # which corrupted this model's tag ("Qwen" -> "Qwe_"); parameter expansion
  # replaces only '/', ':' and spaces.
  TAG=${MODEL//[\/: ]/_}
  OUT="${OUT_DIR}/${TAG}_${BACK}.json"
  python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port "$PORT" --num-questions "${NUM_Q}" --save-results "${OUT}"
  python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY

  cleanup
  SERVER_PID=
  sleep 1
  PORT=$((PORT+1))
done
|
||||
@ -1373,4 +1373,22 @@ steps:
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
|
||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
||||
|
||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
@ -65,7 +65,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
|
||||
# Centralized v1 package - copied to both test and final stages
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
|
||||
|
||||
# -----------------------
|
||||
@ -98,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system hf_transfer
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER=1
|
||||
|
||||
# Copy in the v1 package
|
||||
# Copy in the v1 package (for python-only install test group)
|
||||
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
||||
|
||||
# Source code is used in the `python_only_compile.sh` test
|
||||
@ -130,9 +129,6 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
&& pip uninstall -y vllm \
|
||||
&& uv pip install --system *.whl
|
||||
|
||||
# Copy in the v1 package
|
||||
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
||||
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Copy over the benchmark scripts as well
|
||||
|
||||
@ -108,6 +108,116 @@ networks.
|
||||
Consult your operating system or application platform documentation for specific
|
||||
firewall configuration instructions.
|
||||
|
||||
## API Key Authentication Limitations
|
||||
|
||||
### Overview
|
||||
|
||||
The `--api-key` flag (or `VLLM_API_KEY` environment variable) provides authentication for vLLM's HTTP server, but **only for OpenAI-compatible API endpoints under the `/v1` path prefix**. Many other sensitive endpoints are exposed on the same HTTP server without any authentication enforcement.
|
||||
|
||||
**Important:** Do not rely exclusively on `--api-key` for securing access to vLLM. Additional security measures are required for production deployments.
|
||||
|
||||
### Protected Endpoints (Require API Key)
|
||||
|
||||
When `--api-key` is configured, the following `/v1` endpoints require Bearer token authentication:
|
||||
|
||||
- `/v1/models` - List available models
|
||||
- `/v1/chat/completions` - Chat completions
|
||||
- `/v1/completions` - Text completions
|
||||
- `/v1/embeddings` - Generate embeddings
|
||||
- `/v1/audio/transcriptions` - Audio transcription
|
||||
- `/v1/audio/translations` - Audio translation
|
||||
- `/v1/messages` - Anthropic-compatible messages API
|
||||
- `/v1/responses` - Response management
|
||||
- `/v1/score` - Scoring API
|
||||
- `/v1/rerank` - Reranking API
|
||||
|
||||
### Unprotected Endpoints (No API Key Required)
|
||||
|
||||
The following endpoints **do not require authentication** even when `--api-key` is configured:
|
||||
|
||||
**Inference endpoints:**
|
||||
|
||||
- `/invocations` - SageMaker-compatible endpoint (routes to the same inference functions as `/v1` endpoints)
|
||||
- `/inference/v1/generate` - Generate completions
|
||||
- `/pooling` - Pooling API
|
||||
- `/classify` - Classification API
|
||||
- `/score` - Scoring API (non-`/v1` variant)
|
||||
- `/rerank` - Reranking API (non-`/v1` variant)
|
||||
|
||||
**Operational control endpoints (always enabled):**
|
||||
|
||||
- `/pause` - Pause generation (causes denial of service)
|
||||
- `/resume` - Resume generation
|
||||
- `/scale_elastic_ep` - Trigger scaling operations
|
||||
|
||||
**Utility endpoints:**
|
||||
|
||||
- `/tokenize` - Tokenize text
|
||||
- `/detokenize` - Detokenize tokens
|
||||
- `/health` - Health check
|
||||
- `/ping` - SageMaker health check
|
||||
- `/version` - Version information
|
||||
- `/load` - Server load metrics
|
||||
|
||||
**Tokenizer information endpoint (only when `--enable-tokenizer-info-endpoint` is set):**
|
||||
|
||||
This endpoint is **only available when the `--enable-tokenizer-info-endpoint` flag is set**. It may expose sensitive information such as chat templates and tokenizer configuration:
|
||||
|
||||
- `/tokenizer_info` - Get comprehensive tokenizer information including chat templates and configuration
|
||||
|
||||
**Development endpoints (only when `VLLM_SERVER_DEV_MODE=1`):**
|
||||
|
||||
These endpoints are **only available when the environment variable `VLLM_SERVER_DEV_MODE` is set to `1`**. They are intended for development and debugging purposes and should never be enabled in production:
|
||||
|
||||
- `/server_info` - Get detailed server configuration
|
||||
- `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
|
||||
- `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
|
||||
- `/sleep` - Put engine to sleep (causes denial of service)
|
||||
- `/wake_up` - Wake engine from sleep
|
||||
- `/is_sleeping` - Check if engine is sleeping
|
||||
- `/collective_rpc` - Execute arbitrary RPC methods on the engine (extremely dangerous)
|
||||
|
||||
**Profiler endpoints (only when `VLLM_TORCH_PROFILER_DIR` or `VLLM_TORCH_CUDA_PROFILE` are set):**
|
||||
|
||||
These endpoints are only available when profiling is enabled and should only be used for local development:
|
||||
|
||||
- `/start_profile` - Start PyTorch profiler
|
||||
- `/stop_profile` - Stop PyTorch profiler
|
||||
|
||||
**Note:** The `/invocations` endpoint is particularly concerning as it provides unauthenticated access to the same inference capabilities as the protected `/v1` endpoints.
|
||||
|
||||
### Security Implications
|
||||
|
||||
An attacker who can reach the vLLM HTTP server can:
|
||||
|
||||
1. **Bypass authentication** by using non-`/v1` endpoints like `/invocations`, `/inference/v1/generate`, `/pooling`, `/classify`, `/score`, or `/rerank` to run arbitrary inference without credentials
|
||||
2. **Cause denial of service** by calling `/pause` or `/scale_elastic_ep` without a token
|
||||
3. **Access operational controls** to manipulate server state (e.g., pausing generation)
|
||||
4. **If `--enable-tokenizer-info-endpoint` is set:** Access sensitive tokenizer configuration including chat templates, which may reveal prompt engineering strategies or other implementation details
|
||||
5. **If `VLLM_SERVER_DEV_MODE=1` is set:** Execute arbitrary RPC commands via `/collective_rpc`, reset caches, put the engine to sleep, and access detailed server configuration
|
||||
|
||||
### Recommended Security Practices
|
||||
|
||||
#### 1. Minimize Exposed Endpoints
|
||||
|
||||
**CRITICAL:** Never set `VLLM_SERVER_DEV_MODE=1` in production environments. Development endpoints expose extremely dangerous functionality including:
|
||||
|
||||
- Arbitrary RPC execution via `/collective_rpc`
|
||||
- Cache manipulation that can disrupt service
|
||||
- Detailed server configuration disclosure
|
||||
|
||||
Similarly, never enable profiler endpoints (`VLLM_TORCH_PROFILER_DIR` or `VLLM_TORCH_CUDA_PROFILE`) in production.
|
||||
|
||||
**Be cautious with `--enable-tokenizer-info-endpoint`:** Only enable the `/tokenizer_info` endpoint if you need to expose tokenizer configuration information. This endpoint reveals chat templates and tokenizer settings that may contain sensitive implementation details or prompt engineering strategies.
|
||||
|
||||
#### 2. Deploy Behind a Reverse Proxy
|
||||
|
||||
The most effective approach is to deploy vLLM behind a reverse proxy (such as nginx, Envoy, or a Kubernetes Gateway) that:
|
||||
|
||||
- Explicitly allowlists only the endpoints you want to expose to end users
|
||||
- Blocks all other endpoints, including the unauthenticated inference and operational control endpoints
|
||||
- Implements additional authentication, rate limiting, and logging at the proxy layer
|
||||
|
||||
## Reporting Security Vulnerabilities
|
||||
|
||||
If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md).
|
||||
|
||||
@ -309,6 +309,28 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# HunyuanOCR
def load_hunyuan_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    """Build a multi-image request for tencent/HunyuanOCR.

    Args:
        question: The user question appended after the image placeholders.
        image_urls: URLs of the images to include; each is fetched eagerly.

    Returns:
        A ModelRequestData with the engine args, the formatted prompt, and
        the fetched images.
    """
    model_name = "tencent/HunyuanOCR"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        # Allow as many images per prompt as URLs were supplied.
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # One begin/content/end image-placeholder triple per image.
    # NOTE(review): these tokens use U+2581 (lower one-eighth block), not
    # underscores — keep them byte-identical to the model's tokenizer.
    placeholder = (
        "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>"  # noqa: E501
    ) * len(image_urls)
    prompt = f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>"

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )
||||
|
||||
|
||||
def load_hyperclovax_seed_vision(
|
||||
question: str, image_urls: list[str]
|
||||
) -> ModelRequestData:
|
||||
@ -1322,6 +1344,7 @@ model_example_map = {
|
||||
"deepseek_ocr": load_deepseek_ocr,
|
||||
"gemma3": load_gemma3,
|
||||
"h2ovl_chat": load_h2ovl,
|
||||
"hunyuan_vl": load_hunyuan_vl,
|
||||
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
|
||||
"idefics3": load_idefics3,
|
||||
"interns1": load_interns1,
|
||||
|
||||
@ -70,8 +70,8 @@ torchgeo==0.7.0
|
||||
mteb==2.1.2
|
||||
|
||||
# Data processing
|
||||
xgrammar @ git+https://github.com/mlc-ai/xgrammar.git@eafd4db51b78acc64b3f0764ef27dfd206c28628
|
||||
# Test async scheduling
|
||||
xgrammar==0.1.27
|
||||
# Test async scheduling
|
||||
|
||||
# Utilities
|
||||
num2words==0.5.14
|
||||
|
||||
@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_async_tp=True,
|
||||
fuse_gemm_comms=True,
|
||||
),
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [2, 4, 8],
|
||||
"splitting_ops": [],
|
||||
"pass_config": {"enable_async_tp": async_tp_enabled},
|
||||
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
|
||||
}
|
||||
|
||||
async_tp_args = [
|
||||
|
||||
@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
)
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||
fuse_allreduce_rms=True, eliminate_noops=True
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
|
||||
|
||||
@ -192,7 +192,7 @@ def test_attn_quant(
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(
|
||||
enable_attn_fusion=True,
|
||||
enable_noop=True,
|
||||
enable_fi_allreduce_fusion=True,
|
||||
fuse_attn_quant=True,
|
||||
eliminate_noops=True,
|
||||
fuse_allreduce_rms=True,
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
|
||||
# Common
|
||||
level=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(
|
||||
enable_attn_fusion=True,
|
||||
enable_noop=True,
|
||||
enable_sequence_parallelism=True,
|
||||
enable_async_tp=True,
|
||||
fuse_attn_quant=True,
|
||||
eliminate_noops=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
|
||||
@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
if self.vllm_config.compilation_config.pass_config.enable_fusion:
|
||||
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||
elif RMSNorm.enabled():
|
||||
return [
|
||||
@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
|
||||
@pytest.mark.parametrize("dynamic", [False, True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(
|
||||
@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
num_processes = 2
|
||||
@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
enable_fusion,
|
||||
fuse_norm_quant,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
|
||||
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
|
||||
custom_ops=custom_ops_list,
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
enable_sp=True,
|
||||
fuse_norm_quant=fuse_norm_quant,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
) # NoOp needed for fusion
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if enable_fusion:
|
||||
if fuse_norm_quant:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
|
||||
@ -122,7 +122,9 @@ def test_full_graph(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+rms_norm"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
),
|
||||
*model_info,
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -10,8 +11,9 @@ from pydantic import ValidationError
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.logger import _print_warning_once
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
|
||||
@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
@ -206,7 +208,7 @@ def test_splitting_ops_dynamic():
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
# workaround for accessing all attention ops
|
||||
@ -219,7 +221,7 @@ def test_splitting_ops_dynamic():
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
@ -227,7 +229,7 @@ def test_splitting_ops_dynamic():
|
||||
# With inductor graph partition, attn_fusion and splitting_ops
|
||||
# work together. Default splitting_ops include attention ops.
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
# enable_attn_fusion is directly supported under
|
||||
# fuse_attn_quant is directly supported under
|
||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||
# is unchanged.
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
@ -301,7 +303,7 @@ def test_should_split():
|
||||
"cudagraph_capture_sizes",
|
||||
"max_cudagraph_capture_size",
|
||||
"tp_size",
|
||||
"enable_sequence_parallelism",
|
||||
"enable_sp",
|
||||
"max_num_batched_tokens",
|
||||
"cudagraph_mode",
|
||||
"expected_max_size",
|
||||
@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init(
|
||||
cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size,
|
||||
tp_size,
|
||||
enable_sequence_parallelism,
|
||||
enable_sp,
|
||||
max_num_batched_tokens,
|
||||
cudagraph_mode,
|
||||
expected_max_size,
|
||||
@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init(
|
||||
compilation_config = CompilationConfig(
|
||||
cudagraph_capture_sizes=cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size=max_cudagraph_capture_size,
|
||||
pass_config={
|
||||
"enable_sequence_parallelism": enable_sequence_parallelism,
|
||||
"enable_fusion": True,
|
||||
"enable_noop": True,
|
||||
},
|
||||
pass_config=PassConfig(
|
||||
enable_sp=enable_sp,
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
cudagraph_mode=cudagraph_mode,
|
||||
)
|
||||
engine_args = EngineArgs(
|
||||
@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init(
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
|
||||
|
||||
def test_pass_config_deprecation(caplog_vllm):
|
||||
caplog_vllm.set_level(logging.WARNING)
|
||||
|
||||
# Clear cache to ensure warnings are re-issued
|
||||
_print_warning_once.cache_clear()
|
||||
|
||||
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fusion=True)
|
||||
assert "enable_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_norm_quant is True
|
||||
assert config.fuse_act_quant is True
|
||||
assert config.enable_fusion is None
|
||||
|
||||
# Test enable_attn_fusion -> fuse_attn_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_attn_fusion=True)
|
||||
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_attn_quant is True
|
||||
assert config.enable_attn_fusion is None
|
||||
|
||||
# Test enable_noop -> eliminate_noops
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_noop=True)
|
||||
assert "enable_noop is deprecated" in caplog_vllm.text
|
||||
assert config.eliminate_noops is True
|
||||
assert config.enable_noop is None
|
||||
|
||||
# Test enable_sequence_parallelism -> enable_sp
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_sequence_parallelism=True)
|
||||
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
||||
assert config.enable_sp is True
|
||||
assert config.enable_sequence_parallelism is None
|
||||
|
||||
# Test enable_async_tp -> fuse_gemm_comms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_async_tp=True)
|
||||
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_gemm_comms is True
|
||||
assert config.enable_async_tp is None
|
||||
|
||||
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fi_allreduce_fusion=True)
|
||||
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_allreduce_rms is True
|
||||
assert config.enable_fi_allreduce_fusion is None
|
||||
|
||||
@ -223,7 +223,11 @@ def test_fix_functionalization(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
custom_ops=["all"],
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=do_fusion,
|
||||
fuse_act_quant=do_fusion,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
),
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
|
||||
@ -373,7 +373,7 @@ def test_attention_quant_pattern(
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_attn_fusion=True, enable_noop=True
|
||||
fuse_attn_quant=True, eliminate_noops=True
|
||||
)
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
|
||||
@ -51,7 +51,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
pass_config=PassConfig(eliminate_noops=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
@ -99,7 +99,7 @@ def test_non_noop_slice_preserved():
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
pass_config=PassConfig(eliminate_noops=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
|
||||
@ -64,8 +64,11 @@ def test_pass_manager_uuid(callable):
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.compilation_config.pass_config.enable_fusion = (
|
||||
not config2.compilation_config.pass_config.enable_fusion
|
||||
config2.compilation_config.pass_config.fuse_norm_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_norm_quant
|
||||
)
|
||||
config2.compilation_config.pass_config.fuse_act_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_act_quant
|
||||
)
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
|
||||
@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion(
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_noop=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@ -168,7 +168,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -32,7 +32,8 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
enable_fusion: bool
|
||||
fuse_norm_quant: bool
|
||||
fuse_act_quant: bool
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
@ -66,7 +67,8 @@ class SPTestSettings:
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
fuse_norm_quant=False,
|
||||
fuse_act_quant=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
@ -97,7 +99,8 @@ class SPTestSettings:
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
fuse_norm_quant=False,
|
||||
fuse_act_quant=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
@ -126,7 +129,8 @@ class SPTestSettings:
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
enable_fusion=fusion_val,
|
||||
fuse_norm_quant=fusion_val,
|
||||
fuse_act_quant=fusion_val,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False,
|
||||
)
|
||||
@ -162,7 +166,7 @@ def _compare_sp(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available: int,
|
||||
use_inductor_graph_partition: bool,
|
||||
enable_async_tp: bool,
|
||||
fuse_gemm_comms: bool,
|
||||
*,
|
||||
method: Literal["generate", "encode"],
|
||||
is_multimodal: bool,
|
||||
@ -170,7 +174,8 @@ def _compare_sp(
|
||||
(
|
||||
tp_size,
|
||||
pp_size,
|
||||
enable_fusion,
|
||||
fuse_norm_quant,
|
||||
fuse_act_quant,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
@ -248,10 +253,11 @@ def _compare_sp(
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [4, 8],
|
||||
"pass_config": {
|
||||
"enable_sequence_parallelism": True,
|
||||
"enable_async_tp": enable_async_tp,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_noop": True,
|
||||
"enable_sp": True,
|
||||
"fuse_gemm_comms": fuse_gemm_comms,
|
||||
"fuse_norm_quant": fuse_norm_quant,
|
||||
"fuse_act_quant": fuse_act_quant,
|
||||
"eliminate_noops": True,
|
||||
},
|
||||
"use_inductor_graph_partition": use_inductor_graph_partition,
|
||||
}
|
||||
@ -309,7 +315,7 @@ SP_TEST_MODELS = [
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
|
||||
@pytest.mark.parametrize("fuse_gemm_comms", [False]) # TODO: enable async TP
|
||||
@create_new_process_for_each_test()
|
||||
def test_tp_sp_generation(
|
||||
model_id: str,
|
||||
@ -319,7 +325,7 @@ def test_tp_sp_generation(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available,
|
||||
use_inductor_graph_partition: bool,
|
||||
enable_async_tp: bool,
|
||||
fuse_gemm_comms: bool,
|
||||
):
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
@ -328,7 +334,7 @@ def test_tp_sp_generation(
|
||||
if (
|
||||
"fp8" in model_id.lower()
|
||||
and current_platform.get_device_capability() < (9, 0)
|
||||
and (not enable_async_tp)
|
||||
and (not fuse_gemm_comms)
|
||||
):
|
||||
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
|
||||
|
||||
@ -340,7 +346,7 @@ def test_tp_sp_generation(
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
use_inductor_graph_partition,
|
||||
enable_async_tp=enable_async_tp,
|
||||
fuse_gemm_comms=fuse_gemm_comms,
|
||||
method="generate",
|
||||
is_multimodal=False,
|
||||
)
|
||||
|
||||
@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_engine_dead_error():
|
||||
# Import the health function directly to test it in isolation
|
||||
from vllm.entrypoints.openai.api_server import health
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
|
||||
# Create a mock request that simulates what FastAPI would provide
|
||||
mock_request = Mock(spec=Request)
|
||||
|
||||
@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
|
||||
assert response.status == "completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_enable_response_messages(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Hello?",
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
assert response.status == "completed"
|
||||
assert response.input_messages[0]["type"] == "raw_message_tokens"
|
||||
assert type(response.input_messages[0]["message"]) is str
|
||||
assert len(response.input_messages[0]["message"]) > 10
|
||||
assert type(response.input_messages[0]["tokens"][0]) is int
|
||||
assert type(response.output_messages[0]["message"]) is str
|
||||
assert len(response.output_messages[0]["message"]) > 10
|
||||
assert type(response.output_messages[0]["tokens"][0]) is int
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_reasoning_item(client: OpenAI, model_name: str):
|
||||
|
||||
19
tests/models/multimodal/generation/conftest.py
Normal file
19
tests/models/multimodal/generation/conftest.py
Normal file
@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pytest configuration for vLLM tests."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Disable Flash/MemEfficient SDP on ROCm to avoid HF
|
||||
Transformers accuracy issues.
|
||||
"""
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
@ -137,7 +137,7 @@ VLM_TEST_SETTINGS = {
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
"qwen2_5_omni": VLMTestInfo(
|
||||
@ -152,7 +152,7 @@ VLM_TEST_SETTINGS = {
|
||||
auto_cls=AutoModelForTextToWaveform,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
"qwen3_vl": VLMTestInfo(
|
||||
@ -173,7 +173,7 @@ VLM_TEST_SETTINGS = {
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[
|
||||
pytest.mark.core_model,
|
||||
],
|
||||
@ -350,7 +350,7 @@ VLM_TEST_SETTINGS = {
|
||||
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
|
||||
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
|
||||
stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"],
|
||||
image_size_factors=[(), (1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
|
||||
image_size_factors=[(1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
|
||||
),
|
||||
"fuyu": VLMTestInfo(
|
||||
models=["adept/fuyu-8b"],
|
||||
@ -707,7 +707,7 @@ VLM_TEST_SETTINGS = {
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForCausalLM,
|
||||
image_size_factors=[(), (0.25,)],
|
||||
image_size_factors=[(0.25,)],
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
Version(TRANSFORMERS_VERSION) == Version("4.57.3"),
|
||||
@ -760,7 +760,7 @@ VLM_TEST_SETTINGS = {
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[pytest.mark.cpu_model],
|
||||
),
|
||||
"skywork_r1v": VLMTestInfo(
|
||||
@ -812,7 +812,7 @@ VLM_TEST_SETTINGS = {
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[pytest.mark.skip("Model initialization hangs")],
|
||||
),
|
||||
### Tensor parallel / multi-gpu broadcast tests
|
||||
|
||||
@ -62,6 +62,65 @@ def get_filtered_test_settings(
|
||||
return matching_tests
|
||||
|
||||
|
||||
def get_model_type_cases(
|
||||
model_type: str,
|
||||
test_info: VLMTestInfo,
|
||||
test_type: VLMTestType,
|
||||
):
|
||||
# Ensure that something is wrapped as an iterable it's not already
|
||||
ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,)
|
||||
|
||||
# This is essentially the same as nesting a bunch of mark.parametrize
|
||||
# decorators, but we do it programmatically to allow overrides for on
|
||||
# a per-model basis, while still being able to execute each of these
|
||||
# as individual test cases in pytest.
|
||||
iter_kwargs = OrderedDict(
|
||||
[
|
||||
("model", ensure_wrapped(test_info.models)),
|
||||
("max_tokens", ensure_wrapped(test_info.max_tokens)),
|
||||
("num_logprobs", ensure_wrapped(test_info.num_logprobs)),
|
||||
("dtype", ensure_wrapped(test_info.dtype)),
|
||||
(
|
||||
"distributed_executor_backend",
|
||||
ensure_wrapped(test_info.distributed_executor_backend),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# num_frames is video only
|
||||
if test_type == VLMTestType.VIDEO:
|
||||
iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
|
||||
iter_kwargs["needs_video_metadata"] = ensure_wrapped(
|
||||
test_info.needs_video_metadata
|
||||
)
|
||||
|
||||
# No sizes passed for custom inputs, since inputs are directly provided
|
||||
if test_type not in (
|
||||
VLMTestType.CUSTOM_INPUTS,
|
||||
VLMTestType.AUDIO,
|
||||
):
|
||||
wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
|
||||
if wrapped_sizes is None:
|
||||
raise ValueError(f"Sizes must be set for test type {test_type}")
|
||||
iter_kwargs["size_wrapper"] = wrapped_sizes
|
||||
|
||||
# Otherwise expand the custom test options instead
|
||||
elif test_type == VLMTestType.CUSTOM_INPUTS:
|
||||
if test_info.custom_test_opts is None:
|
||||
raise ValueError("Test has type CUSTOM_INPUTS, but none given")
|
||||
iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
|
||||
|
||||
# Wrap all model cases in a pytest parameter & pass marks through
|
||||
return [
|
||||
pytest.param(
|
||||
model_type,
|
||||
ExpandableVLMTestArgs(**{k: v for k, v in zip(iter_kwargs.keys(), case)}),
|
||||
marks=test_info.marks if test_info.marks is not None else [],
|
||||
)
|
||||
for case in list(itertools.product(*iter_kwargs.values()))
|
||||
]
|
||||
|
||||
|
||||
def get_parametrized_options(
|
||||
test_settings: dict[str, VLMTestInfo],
|
||||
test_type: VLMTestType,
|
||||
@ -76,64 +135,11 @@ def get_parametrized_options(
|
||||
test_settings, test_type, create_new_process_for_each_test
|
||||
)
|
||||
|
||||
# Ensure that something is wrapped as an iterable it's not already
|
||||
ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,)
|
||||
|
||||
def get_model_type_cases(model_type: str, test_info: VLMTestInfo):
|
||||
# This is essentially the same as nesting a bunch of mark.parametrize
|
||||
# decorators, but we do it programmatically to allow overrides for on
|
||||
# a per-model basis, while still being able to execute each of these
|
||||
# as individual test cases in pytest.
|
||||
iter_kwargs = OrderedDict(
|
||||
[
|
||||
("model", ensure_wrapped(test_info.models)),
|
||||
("max_tokens", ensure_wrapped(test_info.max_tokens)),
|
||||
("num_logprobs", ensure_wrapped(test_info.num_logprobs)),
|
||||
("dtype", ensure_wrapped(test_info.dtype)),
|
||||
(
|
||||
"distributed_executor_backend",
|
||||
ensure_wrapped(test_info.distributed_executor_backend),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# num_frames is video only
|
||||
if test_type == VLMTestType.VIDEO:
|
||||
iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
|
||||
iter_kwargs["needs_video_metadata"] = ensure_wrapped(
|
||||
test_info.needs_video_metadata
|
||||
)
|
||||
|
||||
# No sizes passed for custom inputs, since inputs are directly provided
|
||||
if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
|
||||
wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
|
||||
if wrapped_sizes is None:
|
||||
raise ValueError(f"Sizes must be set for test type {test_type}")
|
||||
iter_kwargs["size_wrapper"] = wrapped_sizes
|
||||
|
||||
# Otherwise expand the custom test options instead
|
||||
elif test_type == VLMTestType.CUSTOM_INPUTS:
|
||||
if test_info.custom_test_opts is None:
|
||||
raise ValueError("Test has type CUSTOM_INPUTS, but none given")
|
||||
iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
|
||||
|
||||
# Wrap all model cases in a pytest parameter & pass marks through
|
||||
return [
|
||||
pytest.param(
|
||||
model_type,
|
||||
ExpandableVLMTestArgs(
|
||||
**{k: v for k, v in zip(iter_kwargs.keys(), case)}
|
||||
),
|
||||
marks=test_info.marks if test_info.marks is not None else [],
|
||||
)
|
||||
for case in list(itertools.product(*iter_kwargs.values()))
|
||||
]
|
||||
|
||||
# Get a list per model type, where each entry contains a tuple of all of
|
||||
# that model type's cases, then flatten them into the top level so that
|
||||
# we can consume them in one mark.parametrize call.
|
||||
cases_by_model_type = [
|
||||
get_model_type_cases(model_type, test_info)
|
||||
get_model_type_cases(model_type, test_info, test_type)
|
||||
for model_type, test_info in matching_tests.items()
|
||||
]
|
||||
return list(itertools.chain(*cases_by_model_type))
|
||||
|
||||
@ -50,8 +50,8 @@ MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PL
|
||||
VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?"
|
||||
|
||||
|
||||
IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)]
|
||||
EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)]
|
||||
IMAGE_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)]
|
||||
EMBEDDING_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0)]
|
||||
RunnerOutput = tuple[list[int], str, SampleLogprobs | None]
|
||||
|
||||
|
||||
|
||||
@ -47,6 +47,12 @@ QWEN2_CONFIG = GGUFTestConfig(
|
||||
gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
|
||||
)
|
||||
|
||||
QWEN3_CONFIG = GGUFTestConfig(
|
||||
original_model="Qwen/Qwen3-0.6B",
|
||||
gguf_repo="unsloth/Qwen3-0.6B-GGUF",
|
||||
gguf_filename="Qwen3-0.6B-BF16.gguf",
|
||||
)
|
||||
|
||||
PHI3_CONFIG = GGUFTestConfig(
|
||||
original_model="microsoft/Phi-3.5-mini-instruct",
|
||||
gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
|
||||
@ -87,6 +93,7 @@ GEMMA3_CONFIG = GGUFTestConfig(
|
||||
MODELS = [
|
||||
# LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
|
||||
QWEN2_CONFIG,
|
||||
QWEN3_CONFIG,
|
||||
PHI3_CONFIG,
|
||||
GPT2_CONFIG,
|
||||
STABLELM_CONFIG,
|
||||
|
||||
@ -1023,17 +1023,17 @@ def test_vllm_config_explicit_overrides():
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
|
||||
# Explicit pass config flags to override defaults
|
||||
pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
|
||||
pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
|
||||
compilation_config = CompilationConfig(pass_config=pass_config)
|
||||
config = VllmConfig(
|
||||
optimization_level=OptimizationLevel.O0,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert config.compilation_config.pass_config.enable_noop is True
|
||||
assert config.compilation_config.pass_config.enable_attn_fusion is True
|
||||
assert config.compilation_config.pass_config.eliminate_noops is True
|
||||
assert config.compilation_config.pass_config.fuse_attn_quant is True
|
||||
|
||||
# Explicit cudagraph mode override on quantized model at O2
|
||||
pass_config = PassConfig(enable_async_tp=True)
|
||||
pass_config = PassConfig(fuse_gemm_comms=True)
|
||||
compilation_config = CompilationConfig(
|
||||
cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
|
||||
)
|
||||
@ -1043,7 +1043,7 @@ def test_vllm_config_explicit_overrides():
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
assert config.compilation_config.pass_config.enable_async_tp is True
|
||||
assert config.compilation_config.pass_config.fuse_gemm_comms is True
|
||||
# Mode should still use default for O2
|
||||
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
|
||||
@ -1093,7 +1093,7 @@ def test_vllm_config_explicit_overrides():
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
# Explicit override should be respected
|
||||
assert config.compilation_config.pass_config.enable_noop is False
|
||||
assert config.compilation_config.pass_config.eliminate_noops is False
|
||||
# Other fields should still use defaults
|
||||
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any
|
||||
@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CUDAGraphStat:
|
||||
num_unpadded_tokens: int
|
||||
num_padded_tokens: int
|
||||
num_paddings: int
|
||||
runtime_mode: str
|
||||
|
||||
|
||||
class CUDAGraphLogging:
|
||||
"""Aggregate and log cudagraph metrics"""
|
||||
|
||||
COLUMN_HEADERS = [
|
||||
"Unpadded Tokens",
|
||||
"Padded Tokens",
|
||||
"Num Paddings",
|
||||
"Runtime Mode",
|
||||
"Count",
|
||||
]
|
||||
|
||||
def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
|
||||
self.reset()
|
||||
self.cg_mode = str(cg_mode)
|
||||
self.cg_capture_sizes = str(cg_capture_sizes or [])
|
||||
|
||||
self.settings_header = (
|
||||
"**CUDAGraph Config Settings:**\n\n"
|
||||
f"- Mode: {self.cg_mode}\n"
|
||||
f"- Capture sizes: {self.cg_capture_sizes}\n\n"
|
||||
"**CUDAGraph Stats:**\n\n"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self.stats = []
|
||||
|
||||
def observe(self, cudagraph_stat: CUDAGraphStat):
|
||||
self.stats.append(cudagraph_stat)
|
||||
|
||||
def generate_metric_table(self) -> str:
|
||||
stats_counts = Counter(self.stats)
|
||||
|
||||
# Convert stats to rows of strings, in descending order of observed frequencies
|
||||
rows = []
|
||||
for stat, count in sorted(
|
||||
stats_counts.items(), key=lambda item: item[1], reverse=True
|
||||
):
|
||||
rows.append(
|
||||
[
|
||||
str(stat.num_unpadded_tokens),
|
||||
str(stat.num_padded_tokens),
|
||||
str(stat.num_paddings),
|
||||
stat.runtime_mode,
|
||||
str(count),
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate column widths (max of header and data)
|
||||
col_widths = []
|
||||
for i, header_text in enumerate(self.COLUMN_HEADERS):
|
||||
max_width = len(header_text)
|
||||
for row in rows:
|
||||
max_width = max(max_width, len(row[i]))
|
||||
col_widths.append(max_width)
|
||||
|
||||
table_header_list = [
|
||||
h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
|
||||
]
|
||||
table_header = "| " + " | ".join(table_header_list) + " |\n"
|
||||
|
||||
table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"
|
||||
|
||||
# Create data rows with proper alignment
|
||||
data_rows = []
|
||||
for row in rows:
|
||||
formatted_row = [
|
||||
str(val).ljust(width) for val, width in zip(row, col_widths)
|
||||
]
|
||||
data_rows.append("| " + " | ".join(formatted_row) + " |")
|
||||
|
||||
return (
|
||||
self.settings_header
|
||||
+ table_header
|
||||
+ table_separator
|
||||
+ "\n".join(data_rows)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
def log(self, log_fn=logger.info):
|
||||
if not self.stats:
|
||||
return
|
||||
log_fn(self.generate_metric_table())
|
||||
self.reset()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
|
||||
@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
]:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif (
|
||||
at_target
|
||||
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
):
|
||||
mutated_args = {
|
||||
1: "allreduce_in",
|
||||
2: "residual",
|
||||
3: "norm_out",
|
||||
4: "quant_out",
|
||||
5: "scale_out",
|
||||
}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
|
||||
@ -75,8 +75,8 @@ def find_op_nodes(
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
if not op._schema.is_mutable:
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
|
||||
@ -92,22 +92,23 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
# Set the current vllm config to allow tracing CustomOp instances
|
||||
with set_current_vllm_config(config, check_compile=False):
|
||||
if self.pass_config.enable_noop:
|
||||
if self.pass_config.eliminate_noops:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
if self.pass_config.enable_sp:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
if self.pass_config.fuse_gemm_comms:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
if self.pass_config.fuse_allreduce_rms:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
if self.pass_config.fuse_norm_quant:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
if self.pass_config.fuse_act_quant:
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
|
||||
@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.config.utils import config
|
||||
from vllm.config.utils import config, handle_deprecated
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
@ -105,18 +105,43 @@ class PassConfig:
|
||||
improper state.
|
||||
"""
|
||||
|
||||
# New flags
|
||||
fuse_norm_quant: bool = Field(default=None)
|
||||
"""Fuse the custom RMSNorm + quant ops."""
|
||||
fuse_act_quant: bool = Field(default=None)
|
||||
"""Fuse the custom SiluMul + quant ops."""
|
||||
fuse_attn_quant: bool = Field(default=None)
|
||||
"""Fuse the custom attention + quant ops."""
|
||||
eliminate_noops: bool = Field(default=None)
|
||||
"""Eliminate no-op ops."""
|
||||
enable_sp: bool = Field(default=None)
|
||||
"""Enable sequence parallelism."""
|
||||
fuse_gemm_comms: bool = Field(default=None)
|
||||
"""Enable async TP."""
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
|
||||
# Deprecated flags
|
||||
enable_fusion: bool = Field(default=None)
|
||||
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
|
||||
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
|
||||
"""
|
||||
enable_attn_fusion: bool = Field(default=None)
|
||||
"""Whether to enable the custom attention+quant fusion pass."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_noop: bool = Field(default=None)
|
||||
"""Whether to enable the custom no-op elimination pass."""
|
||||
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_sequence_parallelism: bool = Field(default=None)
|
||||
"""Whether to enable sequence parallelism."""
|
||||
"""Deprecated in: v0.12.0. Use enable_sp instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_async_tp: bool = Field(default=None)
|
||||
"""Whether to enable async TP."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_fi_allreduce_fusion: bool = Field(default=None)
|
||||
"""Whether to enable flashinfer allreduce fusion."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
vllm should use flashinfer fused allreduce. Specified as a
|
||||
@ -136,7 +161,7 @@ class PassConfig:
|
||||
},
|
||||
}, where key is the device capability"""
|
||||
enable_qk_norm_rope_fusion: bool = False
|
||||
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
|
||||
"""Enable fused Q/K RMSNorm + RoPE pass."""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
@ -174,6 +199,13 @@ class PassConfig:
|
||||
return InductorPass.hash_dict(asdict(self))
|
||||
|
||||
@field_validator(
|
||||
"fuse_norm_quant",
|
||||
"fuse_act_quant",
|
||||
"fuse_attn_quant",
|
||||
"eliminate_noops",
|
||||
"enable_sp",
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
"enable_fusion",
|
||||
"enable_attn_fusion",
|
||||
"enable_noop",
|
||||
@ -190,18 +222,71 @@ class PassConfig:
|
||||
return handler(value)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.enable_noop:
|
||||
if self.enable_fusion:
|
||||
# Handle deprecation and defaults
|
||||
|
||||
# Map old flags to new flags and issue warnings
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_fusion",
|
||||
["fuse_norm_quant", "fuse_act_quant"],
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_attn_fusion",
|
||||
"fuse_attn_quant",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_sequence_parallelism",
|
||||
"enable_sp",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_async_tp",
|
||||
"fuse_gemm_comms",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_fi_allreduce_fusion",
|
||||
"fuse_allreduce_rms",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_noop",
|
||||
"eliminate_noops",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
# Force old flags to None to ensure they are not used
|
||||
self.enable_fusion = None
|
||||
self.enable_attn_fusion = None
|
||||
self.enable_noop = None
|
||||
self.enable_sequence_parallelism = None
|
||||
self.enable_async_tp = None
|
||||
self.enable_fi_allreduce_fusion = None
|
||||
|
||||
if not self.eliminate_noops:
|
||||
if self.fuse_norm_quant or self.fuse_act_quant:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm/SiluMul + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_attn_fusion:
|
||||
if self.fuse_attn_quant:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Attention + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_fi_allreduce_fusion:
|
||||
if self.fuse_allreduce_rms:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||
@ -873,7 +958,7 @@ class CompilationConfig:
|
||||
self.set_splitting_ops_for_inductor_graph_partition()
|
||||
return
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
# here use_inductor_graph_partition is False
|
||||
self.set_splitting_ops_for_attn_fusion()
|
||||
return
|
||||
@ -915,12 +1000,12 @@ class CompilationConfig:
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
|
||||
def set_splitting_ops_for_attn_fusion(self):
|
||||
assert self.pass_config.enable_attn_fusion
|
||||
assert self.pass_config.fuse_attn_quant
|
||||
if self.splitting_ops is None:
|
||||
self.splitting_ops = []
|
||||
if self.cudagraph_mode.has_piecewise_cudagraphs():
|
||||
logger.warning_once(
|
||||
"enable_attn_fusion is incompatible with piecewise "
|
||||
"fuse_attn_quant is incompatible with piecewise "
|
||||
"cudagraph when use_inductor_graph_partition is off. "
|
||||
"In this case, splitting_ops will be set to empty "
|
||||
"list, and cudagraph_mode will be set to FULL. "
|
||||
@ -931,8 +1016,7 @@ class CompilationConfig:
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
assert not self.splitting_ops_contain_attention(), (
|
||||
"attention ops should not be in splitting_ops "
|
||||
"when enable_attn_fusion is True"
|
||||
"attention ops should not be in splitting_ops when fuse_attn_quant is True"
|
||||
)
|
||||
|
||||
def splitting_ops_contain_attention(self) -> bool:
|
||||
@ -1008,7 +1092,7 @@ class CompilationConfig:
|
||||
self, uniform_decode_query_len: int, tensor_parallel_size: int
|
||||
):
|
||||
multiple_of = uniform_decode_query_len
|
||||
if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
|
||||
if tensor_parallel_size > 1 and self.pass_config.enable_sp:
|
||||
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
|
||||
if (
|
||||
multiple_of % uniform_decode_query_len != 0
|
||||
|
||||
@ -55,6 +55,10 @@ class ObservabilityConfig:
|
||||
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
|
||||
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
|
||||
|
||||
cudagraph_metrics: bool = False
|
||||
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
|
||||
dispatch modes, and their observed frequencies at every logging interval)."""
|
||||
|
||||
@cached_property
|
||||
def collect_model_forward_time(self) -> bool:
|
||||
"""Whether to collect model forward time for the request."""
|
||||
|
||||
@ -19,6 +19,10 @@ import torch
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
else:
|
||||
@ -293,3 +297,28 @@ def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, ob
|
||||
def hash_factors(items: dict[str, object]) -> str:
|
||||
"""Return a SHA-256 hex digest of the canonical items structure."""
|
||||
return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
|
||||
def handle_deprecated(
|
||||
config: ConfigT,
|
||||
old_name: str,
|
||||
new_name_or_names: str | list[str],
|
||||
removal_version: str,
|
||||
) -> None:
|
||||
old_val = getattr(config, old_name)
|
||||
if old_val is None:
|
||||
return
|
||||
|
||||
if isinstance(new_name_or_names, str):
|
||||
new_names = [new_name_or_names]
|
||||
else:
|
||||
new_names = new_name_or_names
|
||||
|
||||
msg = (
|
||||
f"{old_name} is deprecated and will be removed in {removal_version}. "
|
||||
f"Use {', '.join(new_names)} instead."
|
||||
)
|
||||
logger.warning(msg)
|
||||
|
||||
for new_name in new_names:
|
||||
setattr(config, new_name, old_val)
|
||||
|
||||
@ -83,22 +83,33 @@ IS_DENSE = False
|
||||
# See https://github.com/vllm-project/vllm/issues/25689.
|
||||
|
||||
|
||||
def enable_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Returns True if RMS norm or quant FP8 is enabled."""
|
||||
def enable_norm_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Enable if either RMS norm or quant FP8 custom op is active;
|
||||
otherwise Inductor handles fusion."""
|
||||
|
||||
return cfg.compilation_config.is_custom_op_enabled(
|
||||
"rms_norm"
|
||||
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||
|
||||
|
||||
def enable_act_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Enable if either SiLU+Mul or quant FP8 custom op is active;
|
||||
otherwise Inductor handles fusion."""
|
||||
return cfg.compilation_config.is_custom_op_enabled(
|
||||
"silu_and_mul"
|
||||
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||
|
||||
|
||||
OPTIMIZATION_LEVEL_00 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"enable_noop": False,
|
||||
"enable_fusion": False,
|
||||
"enable_fi_allreduce_fusion": False,
|
||||
"enable_attn_fusion": False,
|
||||
"enable_sequence_parallelism": False,
|
||||
"enable_async_tp": False,
|
||||
"eliminate_noops": False,
|
||||
"fuse_norm_quant": False,
|
||||
"fuse_act_quant": False,
|
||||
"fuse_allreduce_rms": False,
|
||||
"fuse_attn_quant": False,
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@ -107,12 +118,13 @@ OPTIMIZATION_LEVEL_00 = {
|
||||
OPTIMIZATION_LEVEL_01 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"enable_noop": True,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_fi_allreduce_fusion": False,
|
||||
"enable_attn_fusion": False,
|
||||
"enable_sequence_parallelism": False,
|
||||
"enable_async_tp": False,
|
||||
"eliminate_noops": True,
|
||||
"fuse_norm_quant": enable_norm_fusion,
|
||||
"fuse_act_quant": enable_act_fusion,
|
||||
"fuse_allreduce_rms": False,
|
||||
"fuse_attn_quant": False,
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@ -121,12 +133,13 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
OPTIMIZATION_LEVEL_02 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"enable_noop": True,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_fi_allreduce_fusion": False,
|
||||
"enable_attn_fusion": IS_QUANTIZED,
|
||||
"enable_sequence_parallelism": IS_DENSE,
|
||||
"enable_async_tp": IS_DENSE,
|
||||
"eliminate_noops": True,
|
||||
"fuse_norm_quant": enable_norm_fusion,
|
||||
"fuse_act_quant": enable_act_fusion,
|
||||
"fuse_allreduce_rms": False,
|
||||
"fuse_attn_quant": IS_QUANTIZED,
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@ -135,12 +148,13 @@ OPTIMIZATION_LEVEL_02 = {
|
||||
OPTIMIZATION_LEVEL_03 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"enable_noop": True,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_fi_allreduce_fusion": False,
|
||||
"enable_attn_fusion": IS_QUANTIZED,
|
||||
"enable_sequence_parallelism": IS_DENSE,
|
||||
"enable_async_tp": IS_DENSE,
|
||||
"eliminate_noops": True,
|
||||
"fuse_norm_quant": enable_norm_fusion,
|
||||
"fuse_act_quant": enable_act_fusion,
|
||||
"fuse_allreduce_rms": False,
|
||||
"fuse_attn_quant": IS_QUANTIZED,
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@ -645,9 +659,9 @@ class VllmConfig:
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.enable_async_tp:
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism = True
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
if self.compilation_config.pass_config.fuse_gemm_comms:
|
||||
self.compilation_config.pass_config.enable_sp = True
|
||||
if self.compilation_config.pass_config.enable_sp:
|
||||
if "-rms_norm" in self.compilation_config.custom_ops:
|
||||
logger.warning(
|
||||
"RMS norm force disabled, sequence parallelism might break"
|
||||
@ -797,7 +811,7 @@ class VllmConfig:
|
||||
# Do this after all the updates to compilation_config.mode
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
if self.compilation_config.pass_config.enable_sp:
|
||||
# With pipeline parallelism or dynamo partitioning,
|
||||
# native rms norm tracing errors due to incorrect residual shape.
|
||||
# Use custom rms norm to unblock. In the future,
|
||||
@ -1062,7 +1076,7 @@ class VllmConfig:
|
||||
|
||||
if (
|
||||
self.parallel_config.tensor_parallel_size > 1
|
||||
and self.compilation_config.pass_config.enable_sequence_parallelism
|
||||
and self.compilation_config.pass_config.enable_sp
|
||||
):
|
||||
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
|
||||
cudagraph_capture_sizes
|
||||
|
||||
@ -322,9 +322,6 @@ async def transfer_layer(
|
||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
# A buffer to hold the expert weights in one layer during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
|
||||
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
|
||||
@ -518,6 +518,7 @@ class EngineArgs:
|
||||
kv_cache_metrics_sample: float = get_field(
|
||||
ObservabilityConfig, "kv_cache_metrics_sample"
|
||||
)
|
||||
cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
|
||||
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
|
||||
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
|
||||
|
||||
@ -1021,6 +1022,10 @@ class EngineArgs:
|
||||
"--kv-cache-metrics-sample",
|
||||
**observability_kwargs["kv_cache_metrics_sample"],
|
||||
)
|
||||
observability_group.add_argument(
|
||||
"--cudagraph-metrics",
|
||||
**observability_kwargs["cudagraph_metrics"],
|
||||
)
|
||||
|
||||
# Scheduler arguments
|
||||
scheduler_kwargs = get_kwargs(SchedulerConfig)
|
||||
@ -1698,6 +1703,7 @@ class EngineArgs:
|
||||
collect_detailed_traces=self.collect_detailed_traces,
|
||||
kv_cache_metrics=self.kv_cache_metrics,
|
||||
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
|
||||
cudagraph_metrics=self.cudagraph_metrics,
|
||||
)
|
||||
|
||||
# Compilation config overrides
|
||||
|
||||
@ -118,6 +118,7 @@ async def init_app(
|
||||
)
|
||||
)
|
||||
app.state.engine_client = engine
|
||||
app.state.args = args
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@ -109,6 +109,10 @@ def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser
|
||||
help=(
|
||||
"API key for OpenAI services. If provided, this api key "
|
||||
"will overwrite the api key obtained through environment variables."
|
||||
" It is important to note that this option only applies to the "
|
||||
"OpenAI-compatible API endpoints and NOT other endpoints that may "
|
||||
"be present in the server. See the security guide in the vLLM docs "
|
||||
"for more details."
|
||||
),
|
||||
)
|
||||
return parser
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.entrypoints.openai.parser.responses_parser import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponseRawMessageAndToken,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.responses_utils import construct_tool_dicts
|
||||
@ -148,6 +149,8 @@ def _create_json_parse_error_messages(
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
"""This is a context that cannot handle MCP tool calls"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
self.num_prompt_tokens = 0
|
||||
@ -158,6 +161,9 @@ class SimpleContext(ConversationContext):
|
||||
# not implemented yet for SimpleContext
|
||||
self.all_turn_metrics = []
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.output_messages: list[ResponseRawMessageAndToken] = []
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
if not isinstance(output, RequestOutput):
|
||||
@ -166,6 +172,22 @@ class SimpleContext(ConversationContext):
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
|
||||
if len(self.input_messages) == 0:
|
||||
output_prompt = output.prompt or ""
|
||||
output_prompt_token_ids = output.prompt_token_ids or []
|
||||
self.input_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output_prompt,
|
||||
tokens=output_prompt_token_ids,
|
||||
)
|
||||
)
|
||||
self.output_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output.outputs[0].text,
|
||||
tokens=output.outputs[0].token_ids,
|
||||
)
|
||||
)
|
||||
|
||||
def append_tool_output(self, output) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
@ -20,21 +20,15 @@ from http import HTTPStatus
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import prometheus_client
|
||||
import pydantic
|
||||
import regex as re
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders, State
|
||||
from starlette.routing import Mount
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
StreamingResponsesResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponseVariant,
|
||||
TranslationRequest,
|
||||
@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
|
||||
OpenAIServingModels,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
||||
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||
from vllm.entrypoints.openai.serving_tokens import ServingTokens
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
ScalingMiddleware,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
|
||||
from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.network_utils import is_valid_ipv6_address
|
||||
from vllm.utils.system_utils import decorate_logs, set_ulimit
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PrometheusResponse(Response):
|
||||
media_type = prometheus_client.CONTENT_TYPE_LATEST
|
||||
|
||||
|
||||
def mount_metrics(app: FastAPI):
|
||||
"""Mount prometheus metrics to a FastAPI app."""
|
||||
|
||||
registry = get_prometheus_registry()
|
||||
|
||||
# `response_class=PrometheusResponse` is needed to return an HTTP response
|
||||
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
|
||||
# instead of the default "application/json" which is incorrect.
|
||||
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
|
||||
Instrumentator(
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
"/health",
|
||||
"/load",
|
||||
"/ping",
|
||||
"/version",
|
||||
"/server_info",
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(app).expose(app, response_class=PrometheusResponse)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
|
||||
# Workaround for 307 Redirect for /metrics
|
||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def base(request: Request) -> OpenAIServing:
|
||||
# Reuse the existing instance
|
||||
return tokenization(request)
|
||||
@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
|
||||
return request.app.state.serving_tokens
|
||||
|
||||
|
||||
@router.get("/health", response_class=Response)
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
try:
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
except EngineDeadError:
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
@router.get("/load")
|
||||
async def get_server_load_metrics(request: Request):
|
||||
# This endpoint returns the current server load metrics.
|
||||
@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
|
||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_generation(
|
||||
raw_request: Request,
|
||||
wait_for_inflight_requests: bool = Query(False),
|
||||
clear_cache: bool = Query(True),
|
||||
) -> JSONResponse:
|
||||
"""Pause generation requests to allow weight updates.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
aborts any in-flight requests immediately.
|
||||
clear_cache: Whether to clear KV/prefix caches after draining.
|
||||
"""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.pause_generation(
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
clear_cache=clear_cache,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={"status": "paused"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
return JSONResponse(
|
||||
content={"error": str(err)},
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to pause generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to pause generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_generation(raw_request: Request) -> JSONResponse:
|
||||
"""Resume generation after a pause."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.resume_generation()
|
||||
return JSONResponse(
|
||||
content={"status": "resumed"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to resume generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to resume generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/is_paused")
|
||||
async def is_paused(raw_request: Request) -> JSONResponse:
|
||||
"""Return the current pause status."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
paused = await engine.is_paused()
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to fetch pause status")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to fetch pause status: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
return JSONResponse(content={"is_paused": paused})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
except NotImplementedError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, TokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/detokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
except OverflowError as e:
|
||||
raise RequestValidationError(errors=[str(e)]) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, DetokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
def maybe_register_tokenizer_info_endpoint(args):
|
||||
"""Conditionally register the tokenizer info endpoint if enabled."""
|
||||
if getattr(args, "enable_tokenizer_info_endpoint", False):
|
||||
|
||||
@router.get("/tokenizer_info")
|
||||
async def get_tokenizer_info(raw_request: Request):
|
||||
"""Get comprehensive tokenizer information."""
|
||||
result = await tokenization(raw_request).get_tokenizer_info()
|
||||
return JSONResponse(
|
||||
content=result.model_dump(),
|
||||
status_code=result.error.code
|
||||
if isinstance(result, ErrorResponse)
|
||||
else 200,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models(raw_request: Request):
|
||||
handler = models(raw_request)
|
||||
@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
await engine_client(raw_request).reset_mm_cache()
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/sleep")
|
||||
async def sleep(raw_request: Request):
|
||||
# get POST params
|
||||
level = raw_request.query_params.get("level", "1")
|
||||
await engine_client(raw_request).sleep(int(level))
|
||||
# FIXME: in v0 with frontend multiprocessing, the sleep command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/wake_up")
|
||||
async def wake_up(raw_request: Request):
|
||||
tags = raw_request.query_params.getlist("tags")
|
||||
if tags == []:
|
||||
# set to None to wake up all tags if no tags are provided
|
||||
tags = None
|
||||
logger.info("wake up the engine with tags: %s", tags)
|
||||
await engine_client(raw_request).wake_up(tags)
|
||||
# FIXME: in v0 with frontend multiprocessing, the wake-up command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.get("/is_sleeping")
|
||||
async def is_sleeping(raw_request: Request):
|
||||
logger.info("check whether the engine is sleeping")
|
||||
is_sleeping = await engine_client(raw_request).is_sleeping()
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
@router.post("/collective_rpc")
|
||||
async def collective_rpc(raw_request: Request):
|
||||
try:
|
||||
@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
return Response(status_code=200)
|
||||
response: list[Any] = []
|
||||
for result in results:
|
||||
if result is None or isinstance(result, (dict, list)):
|
||||
if result is None or isinstance(result, dict | list):
|
||||
response.append(result)
|
||||
else:
|
||||
response.append(str(result))
|
||||
return JSONResponse(content={"results": response})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scale_elastic_ep",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"model": dict},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def scale_elastic_ep(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
|
||||
new_data_parallel_size = body.get("new_data_parallel_size")
|
||||
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
|
||||
|
||||
if new_data_parallel_size is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="new_data_parallel_size is required"
|
||||
)
|
||||
|
||||
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="new_data_parallel_size must be a positive integer"
|
||||
)
|
||||
|
||||
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="drain_timeout must be a positive integer"
|
||||
)
|
||||
|
||||
# Set scaling flag to prevent new requests
|
||||
global _scaling_elastic_ep
|
||||
_scaling_elastic_ep = True
|
||||
client = engine_client(raw_request)
|
||||
try:
|
||||
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
|
||||
}
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408,
|
||||
detail="Scale failed due to request drain timeout "
|
||||
f"after {drain_timeout} seconds",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error("Scale failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Scale failed") from e
|
||||
finally:
|
||||
_scaling_elastic_ep = False
|
||||
|
||||
|
||||
@router.post("/is_scaling_elastic_ep")
|
||||
async def is_scaling_elastic_ep(raw_request: Request):
|
||||
return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/inference/v1/generate",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def generate(request: GenerateRequest, raw_request: Request):
|
||||
handler = generate_tokens(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support generate tokens API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.serve_tokens(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, GenerateResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
logger.warning_once(
|
||||
"CUDA Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile(raw_request: Request):
|
||||
logger.info("Starting profiler...")
|
||||
await engine_client(raw_request).start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile(raw_request: Request):
|
||||
logger.info("Stopping profiler...")
|
||||
await engine_client(raw_request).stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
# Global variable to track scaling state
|
||||
_scaling_elastic_ep = False
|
||||
|
||||
|
||||
class ScalingMiddleware:
|
||||
"""
|
||||
Middleware that checks if the model is currently scaling and
|
||||
returns a 503 Service Unavailable response if it is.
|
||||
|
||||
This middleware applies to all HTTP requests and prevents
|
||||
processing when the model is in a scaling state.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] != "http":
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Check global scaling state
|
||||
global _scaling_elastic_ep
|
||||
if _scaling_elastic_ep:
|
||||
# Return 503 Service Unavailable response
|
||||
response = JSONResponse(
|
||||
content={
|
||||
"error": "The model is currently scaling. Please try again later."
|
||||
},
|
||||
status_code=503,
|
||||
)
|
||||
return response(scope, receive, send)
|
||||
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
|
||||
def _extract_content_from_chunk(chunk_data: dict) -> str:
|
||||
"""Extract content from a streaming response chunk."""
|
||||
try:
|
||||
@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
)
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.state.args = args
|
||||
from vllm.entrypoints.serve import register_vllm_serve_api_routers
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
logger.warning(
|
||||
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!"
|
||||
)
|
||||
from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
|
||||
|
||||
register_dynamic_lora_routes(router)
|
||||
register_vllm_serve_api_routers(app)
|
||||
|
||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||
|
||||
@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
app.root_path = args.root_path
|
||||
|
||||
mount_metrics(app)
|
||||
|
||||
from vllm.entrypoints.pooling import register_pooling_api_routers
|
||||
|
||||
register_pooling_api_routers(app)
|
||||
@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
)
|
||||
|
||||
app = sagemaker_standards.bootstrap(app)
|
||||
# Optional endpoints
|
||||
if args.tokens_only:
|
||||
|
||||
@app.post("/abort_requests")
|
||||
async def abort_requests(raw_request: Request):
|
||||
"""
|
||||
Abort one or more requests. To be used in a
|
||||
Disaggregated Everything setup.
|
||||
"""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
request_ids = body.get("request_ids")
|
||||
if request_ids is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'request_ids' in request body",
|
||||
)
|
||||
# Abort requests in background
|
||||
asyncio.create_task(engine_client(raw_request).abort(request_ids))
|
||||
return Response(status_code=200)
|
||||
|
||||
return app
|
||||
|
||||
@ -1515,7 +1081,7 @@ async def init_app_state(
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
state.vllm_config = vllm_config
|
||||
|
||||
state.args = args
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
@ -1839,7 +1405,6 @@ async def run_server_worker(
|
||||
args,
|
||||
client_config=client_config,
|
||||
) as engine_client:
|
||||
maybe_register_tokenizer_info_endpoint(args)
|
||||
app = build_app(args)
|
||||
|
||||
await init_app_state(engine_client, app.state, args)
|
||||
|
||||
@ -1598,6 +1598,20 @@ def serialize_messages(msgs):
|
||||
return [serialize_message(msg) for msg in msgs] if msgs else None
|
||||
|
||||
|
||||
class ResponseRawMessageAndToken(OpenAIBaseModel):
|
||||
"""Class to show the raw message.
|
||||
If message / tokens diverge, tokens is the source of truth"""
|
||||
|
||||
message: str
|
||||
tokens: list[int]
|
||||
type: Literal["raw_message_tokens"] = "raw_message_tokens"
|
||||
|
||||
|
||||
ResponseInputOutputMessage: TypeAlias = (
|
||||
list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
|
||||
)
|
||||
|
||||
|
||||
class ResponsesResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
# These are populated when enable_response_messages is set to True
|
||||
# NOTE: custom serialization is needed
|
||||
# see serialize_input_messages and serialize_output_messages
|
||||
input_messages: list[ChatCompletionMessageParam] | None = None
|
||||
output_messages: list[ChatCompletionMessageParam] | None = None
|
||||
input_messages: ResponseInputOutputMessage | None = None
|
||||
output_messages: ResponseInputOutputMessage | None = None
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
# NOTE: openAI harmony doesn't serialize TextContent properly,
|
||||
@ -1658,8 +1672,8 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
output: list[ResponseOutputItem],
|
||||
status: ResponseStatus,
|
||||
usage: ResponseUsage | None = None,
|
||||
input_messages: list[ChatCompletionMessageParam] | None = None,
|
||||
output_messages: list[ChatCompletionMessageParam] | None = None,
|
||||
input_messages: ResponseInputOutputMessage | None = None,
|
||||
output_messages: ResponseInputOutputMessage | None = None,
|
||||
) -> "ResponsesResponse":
|
||||
incomplete_details: IncompleteDetails | None = None
|
||||
if status == "incomplete":
|
||||
|
||||
@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
FunctionDefinition,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
ResponsesRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
|
||||
@ -86,6 +86,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseInputOutputMessage,
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
ResponsesRequest,
|
||||
@ -629,8 +630,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# "completed" is implemented as the "catch-all" for now.
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
input_messages = None
|
||||
output_messages = None
|
||||
input_messages: ResponseInputOutputMessage | None = None
|
||||
output_messages: ResponseInputOutputMessage | None = None
|
||||
if self.use_harmony:
|
||||
assert isinstance(context, HarmonyContext)
|
||||
output = self._make_response_output_items_with_harmony(context)
|
||||
@ -670,12 +671,10 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
output = self._make_response_output_items(request, final_output, tokenizer)
|
||||
|
||||
# TODO: context for non-gptoss models doesn't use messages
|
||||
# so we can't get them out yet
|
||||
if request.enable_response_messages:
|
||||
raise NotImplementedError(
|
||||
"enable_response_messages is currently only supported for gpt-oss"
|
||||
)
|
||||
input_messages = context.input_messages
|
||||
output_messages = context.output_messages
|
||||
|
||||
# Calculate usage.
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_tool_output_tokens = 0
|
||||
|
||||
@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
|
||||
completion,
|
||||
create_chat_completion,
|
||||
create_completion,
|
||||
health,
|
||||
validate_json_request,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
|
||||
score,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
|
||||
60
vllm/entrypoints/serve/__init__.py
Normal file
60
vllm/entrypoints/serve/__init__.py
Normal file
@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def register_vllm_serve_api_routers(app: FastAPI):
|
||||
from vllm.entrypoints.serve.lora.api_router import (
|
||||
attach_router as attach_lora_router,
|
||||
)
|
||||
|
||||
attach_lora_router(app)
|
||||
from vllm.entrypoints.serve.elastic_ep.api_router import (
|
||||
attach_router as attach_elastic_ep_router,
|
||||
)
|
||||
|
||||
attach_elastic_ep_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.profile.api_router import (
|
||||
attach_router as attach_profile_router,
|
||||
)
|
||||
|
||||
attach_profile_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.sleep.api_router import (
|
||||
attach_router as attach_sleep_router,
|
||||
)
|
||||
|
||||
attach_sleep_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.tokenize.api_router import (
|
||||
attach_router as attach_tokenize_router,
|
||||
)
|
||||
|
||||
attach_tokenize_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.disagg.api_router import (
|
||||
attach_router as attach_disagg_router,
|
||||
)
|
||||
|
||||
attach_disagg_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.rlhf.api_router import (
|
||||
attach_router as attach_rlhf_router,
|
||||
)
|
||||
|
||||
attach_rlhf_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.instrumentator.metrics import (
|
||||
attach_router as attach_metrics_router,
|
||||
)
|
||||
|
||||
attach_metrics_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.instrumentator.health import (
|
||||
attach_router as attach_health_router,
|
||||
)
|
||||
|
||||
attach_health_router(app)
|
||||
0
vllm/entrypoints/serve/disagg/__init__.py
Normal file
0
vllm/entrypoints/serve/disagg/__init__.py
Normal file
110
vllm/entrypoints/serve/disagg/api_router.py
Normal file
110
vllm/entrypoints/serve/disagg/api_router.py
Normal file
@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.protocol import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.serving import (
|
||||
ServingTokens,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def generate_tokens(request: Request) -> ServingTokens | None:
|
||||
return request.app.state.serving_tokens
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/inference/v1/generate",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def generate(request: GenerateRequest, raw_request: Request):
|
||||
handler = generate_tokens(raw_request)
|
||||
if handler is None:
|
||||
return tokenization(raw_request).create_error_response(
|
||||
message="The model does not support generate tokens API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.serve_tokens(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, GenerateResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if getattr(app.state.args, "tokens_only", False):
|
||||
|
||||
@router.post("/abort_requests")
|
||||
async def abort_requests(raw_request: Request):
|
||||
"""
|
||||
Abort one or more requests. To be used in a
|
||||
Disaggregated Everything setup.
|
||||
"""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
request_ids = body.get("request_ids")
|
||||
if request_ids is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'request_ids' in request body",
|
||||
)
|
||||
# Abort requests in background
|
||||
asyncio.create_task(engine_client(raw_request).abort(request_ids))
|
||||
return Response(status_code=200)
|
||||
|
||||
app.include_router(router)
|
||||
90
vllm/entrypoints/serve/disagg/protocol.py
Normal file
90
vllm/entrypoints/serve/disagg/protocol.py
Normal file
@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProbs,
|
||||
Logprob,
|
||||
SamplingParams,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
####### Tokens IN <> Tokens OUT #######
|
||||
class GenerateRequest(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"{random_uuid()}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
token_ids: list[int]
|
||||
"""The token ids to generate text from."""
|
||||
|
||||
# features: MultiModalFeatureSpec
|
||||
# TODO (NickLucche): implement once Renderer work is completed
|
||||
features: str | None = None
|
||||
"""The processed MM inputs for the model."""
|
||||
|
||||
sampling_params: SamplingParams
|
||||
"""The sampling parameters for the model."""
|
||||
|
||||
model: str | None = None
|
||||
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
|
||||
class GenerateResponseChoice(BaseModel):
|
||||
index: int
|
||||
logprobs: ChatCompletionLogProbs | None = None
|
||||
# per OpenAI spec this is the default
|
||||
finish_reason: str | None = "stop"
|
||||
token_ids: list[int] | None = None
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"{random_uuid()}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
choices: list[GenerateResponseChoice]
|
||||
|
||||
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
|
||||
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent,
|
||||
ErrorResponse,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
GenerateResponseChoice,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.disagg.protocol import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
GenerateResponseChoice,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
0
vllm/entrypoints/serve/elastic_ep/__init__.py
Normal file
0
vllm/entrypoints/serve/elastic_ep/__init__.py
Normal file
96
vllm/entrypoints/serve/elastic_ep/api_router.py
Normal file
96
vllm/entrypoints/serve/elastic_ep/api_router.py
Normal file
@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
get_scaling_elastic_ep,
|
||||
set_scaling_elastic_ep,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scale_elastic_ep",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"model": dict},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def scale_elastic_ep(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
|
||||
new_data_parallel_size = body.get("new_data_parallel_size")
|
||||
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
|
||||
|
||||
if new_data_parallel_size is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="new_data_parallel_size is required"
|
||||
)
|
||||
|
||||
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="new_data_parallel_size must be a positive integer",
|
||||
)
|
||||
|
||||
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="drain_timeout must be a positive integer"
|
||||
)
|
||||
|
||||
# Set scaling flag to prevent new requests
|
||||
set_scaling_elastic_ep(True)
|
||||
client = engine_client(raw_request)
|
||||
try:
|
||||
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
|
||||
}
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408,
|
||||
detail="Scale failed due to request drain timeout "
|
||||
f"after {drain_timeout} seconds",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error("Scale failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Scale failed") from e
|
||||
finally:
|
||||
set_scaling_elastic_ep(False)
|
||||
|
||||
|
||||
@router.post("/is_scaling_elastic_ep")
|
||||
async def is_scaling_elastic_ep(raw_request: Request):
|
||||
return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
49
vllm/entrypoints/serve/elastic_ep/middleware.py
Normal file
49
vllm/entrypoints/serve/elastic_ep/middleware.py
Normal file
@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
# Global variable to track scaling state
|
||||
_scaling_elastic_ep = False
|
||||
|
||||
|
||||
def get_scaling_elastic_ep():
|
||||
return _scaling_elastic_ep
|
||||
|
||||
|
||||
def set_scaling_elastic_ep(value):
|
||||
global _scaling_elastic_ep
|
||||
_scaling_elastic_ep = value
|
||||
|
||||
|
||||
class ScalingMiddleware:
|
||||
"""
|
||||
Middleware that checks if the model is currently scaling and
|
||||
returns a 503 Service Unavailable response if it is.
|
||||
|
||||
This middleware applies to all HTTP requests and prevents
|
||||
processing when the model is in a scaling state.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] != "http":
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Check global scaling state
|
||||
if get_scaling_elastic_ep():
|
||||
# Return 503 Service Unavailable response
|
||||
response = JSONResponse(
|
||||
content={
|
||||
"error": "The model is currently scaling. Please try again later."
|
||||
},
|
||||
status_code=503,
|
||||
)
|
||||
return response(scope, receive, send)
|
||||
|
||||
return self.app(scope, receive, send)
|
||||
0
vllm/entrypoints/serve/instrumentator/__init__.py
Normal file
0
vllm/entrypoints/serve/instrumentator/__init__.py
Normal file
33
vllm/entrypoints/serve/instrumentator/health.py
Normal file
33
vllm/entrypoints/serve/instrumentator/health.py
Normal file
@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/health", response_class=Response)
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
try:
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
except EngineDeadError:
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
def attach_router(app):
|
||||
app.include_router(router)
|
||||
46
vllm/entrypoints/serve/instrumentator/metrics.py
Normal file
46
vllm/entrypoints/serve/instrumentator/metrics.py
Normal file
@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import re
|
||||
|
||||
import prometheus_client
|
||||
from fastapi import FastAPI, Response
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.routing import Mount
|
||||
|
||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||
|
||||
|
||||
class PrometheusResponse(Response):
|
||||
media_type = prometheus_client.CONTENT_TYPE_LATEST
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
"""Mount prometheus metrics to a FastAPI app."""
|
||||
|
||||
registry = get_prometheus_registry()
|
||||
|
||||
# `response_class=PrometheusResponse` is needed to return an HTTP response
|
||||
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
|
||||
# instead of the default "application/json" which is incorrect.
|
||||
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
|
||||
Instrumentator(
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
"/health",
|
||||
"/load",
|
||||
"/ping",
|
||||
"/version",
|
||||
"/server_info",
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(app).expose(app, response_class=PrometheusResponse)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
|
||||
# Workaround for 307 Redirect for /metrics
|
||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||
app.routes.append(metrics_route)
|
||||
0
vllm/entrypoints/serve/lora/__init__.py
Normal file
0
vllm/entrypoints/serve/lora/__init__.py
Normal file
@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.api_server import models, validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def register_dynamic_lora_routes(router: APIRouter):
|
||||
def attach_router(app: FastAPI):
|
||||
if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
|
||||
return
|
||||
logger.warning(
|
||||
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!"
|
||||
)
|
||||
|
||||
@sagemaker_standards.register_load_adapter_handler(
|
||||
request_shape={
|
||||
"lora_name": "body.name",
|
||||
@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
|
||||
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
return router
|
||||
# register the router
|
||||
app.include_router(router)
|
||||
0
vllm/entrypoints/serve/profile/__init__.py
Normal file
0
vllm/entrypoints/serve/profile/__init__.py
Normal file
49
vllm/entrypoints/serve/profile/api_router.py
Normal file
49
vllm/entrypoints/serve/profile/api_router.py
Normal file
@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile(raw_request: Request):
|
||||
logger.info("Starting profiler...")
|
||||
await engine_client(raw_request).start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile(raw_request: Request):
|
||||
logger.info("Stopping profiler...")
|
||||
await engine_client(raw_request).stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
logger.warning_once(
|
||||
"CUDA Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
app.include_router(router)
|
||||
0
vllm/entrypoints/serve/rlhf/__init__.py
Normal file
0
vllm/entrypoints/serve/rlhf/__init__.py
Normal file
102
vllm/entrypoints/serve/rlhf/api_router.py
Normal file
102
vllm/entrypoints/serve/rlhf/api_router.py
Normal file
@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_generation(
|
||||
raw_request: Request,
|
||||
wait_for_inflight_requests: bool = Query(False),
|
||||
clear_cache: bool = Query(True),
|
||||
) -> JSONResponse:
|
||||
"""Pause generation requests to allow weight updates.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
aborts any in-flight requests immediately.
|
||||
clear_cache: Whether to clear KV/prefix caches after draining.
|
||||
"""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.pause_generation(
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
clear_cache=clear_cache,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={"status": "paused"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
return JSONResponse(
|
||||
content={"error": str(err)},
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to pause generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to pause generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_generation(raw_request: Request) -> JSONResponse:
|
||||
"""Resume generation after a pause."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.resume_generation()
|
||||
return JSONResponse(
|
||||
content={"status": "resumed"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to resume generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to resume generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/is_paused")
|
||||
async def is_paused(raw_request: Request) -> JSONResponse:
|
||||
"""Return the current pause status."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
paused = await engine.is_paused()
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to fetch pause status")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to fetch pause status: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
return JSONResponse(content={"is_paused": paused})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
0
vllm/entrypoints/serve/sleep/__init__.py
Normal file
0
vllm/entrypoints/serve/sleep/__init__.py
Normal file
60
vllm/entrypoints/serve/sleep/api_router.py
Normal file
60
vllm/entrypoints/serve/sleep/api_router.py
Normal file
@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sleep")
|
||||
async def sleep(raw_request: Request):
|
||||
# get POST params
|
||||
level = raw_request.query_params.get("level", "1")
|
||||
await engine_client(raw_request).sleep(int(level))
|
||||
# FIXME: in v0 with frontend multiprocessing, the sleep command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/wake_up")
|
||||
async def wake_up(raw_request: Request):
|
||||
tags = raw_request.query_params.getlist("tags")
|
||||
if tags == []:
|
||||
# set to None to wake up all tags if no tags are provided
|
||||
tags = None
|
||||
logger.info("wake up the engine with tags: %s", tags)
|
||||
await engine_client(raw_request).wake_up(tags)
|
||||
# FIXME: in v0 with frontend multiprocessing, the wake-up command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/is_sleeping")
|
||||
async def is_sleeping(raw_request: Request):
|
||||
logger.info("check whether the engine is sleeping")
|
||||
is_sleeping = await engine_client(raw_request).is_sleeping()
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if not envs.VLLM_SERVER_DEV_MODE:
|
||||
return
|
||||
logger.warning(
|
||||
"SECURITY WARNING: Development endpoints are enabled! "
|
||||
"This should NOT be used in production!"
|
||||
)
|
||||
|
||||
app.include_router(router)
|
||||
0
vllm/entrypoints/serve/tokenize/__init__.py
Normal file
0
vllm/entrypoints/serve/tokenize/__init__.py
Normal file
118
vllm/entrypoints/serve/tokenize/api_router.py
Normal file
118
vllm/entrypoints/serve/tokenize/api_router.py
Normal file
@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.utils import (
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
except NotImplementedError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, TokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/detokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
except OverflowError as e:
|
||||
raise RequestValidationError(errors=[str(e)]) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, DetokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if getattr(app.state.args, "enable_tokenizer_info_endpoint", False):
|
||||
"""Conditionally register the tokenizer info endpoint if enabled."""
|
||||
|
||||
@router.get("/tokenizer_info")
|
||||
async def get_tokenizer_info(raw_request: Request):
|
||||
"""Get comprehensive tokenizer information."""
|
||||
result = await tokenization(raw_request).get_tokenizer_info()
|
||||
return JSONResponse(
|
||||
content=result.model_dump(),
|
||||
status_code=result.error.code
|
||||
if isinstance(result, ErrorResponse)
|
||||
else 200,
|
||||
)
|
||||
|
||||
app.include_router(router)
|
||||
@ -767,8 +767,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
targets=self.target_scheme_map.keys(),
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
)
|
||||
|
||||
return self.target_scheme_map[matched_target]
|
||||
scheme_dict = self.target_scheme_map[matched_target]
|
||||
if scheme_dict.get("format") is None:
|
||||
scheme_dict["format"] = self.quant_format
|
||||
return scheme_dict
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -7,7 +7,11 @@ from enum import Enum
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
|
||||
from compressed_tensors.quantization import (
|
||||
ActivationOrdering,
|
||||
QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
)
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -142,10 +146,26 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
# are supported + check if the layer is being ignored.
|
||||
weight_quant = scheme_dict.get("weights")
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
format = scheme_dict.get("format")
|
||||
|
||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
# group_size=None means channelwise
|
||||
group_size = weight_quant.group_size or -1
|
||||
|
||||
valid_format_and_bits = (
|
||||
weight_quant.num_bits in WNA16_SUPPORTED_BITS
|
||||
and format == CompressionFormat.pack_quantized.value
|
||||
)
|
||||
|
||||
if not valid_format_and_bits:
|
||||
raise ValueError(
|
||||
"For Fused MoE layers, only format: ",
|
||||
f"{CompressionFormat.pack_quantized.value} ",
|
||||
f" and bits: {WNA16_SUPPORTED_BITS} is supported ",
|
||||
f"but got format: {CompressionFormat.pack_quantized.value} "
|
||||
f" and bits: {weight_quant.num_bits}",
|
||||
)
|
||||
|
||||
# Prefer to use the MarlinMoE kernel when it is supported.
|
||||
if (
|
||||
not check_moe_marlin_supports_layer(layer, group_size)
|
||||
@ -161,12 +181,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
|
||||
return CompressedTensorsWNA16MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
weight_quant, input_quant, layer.moe_config
|
||||
)
|
||||
else:
|
||||
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
weight_quant, input_quant, layer.moe_config
|
||||
)
|
||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
|
||||
@ -176,15 +196,15 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
|
||||
):
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
weight_quant, input_quant, layer.moe_config
|
||||
)
|
||||
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
weight_quant, input_quant, layer.moe_config
|
||||
)
|
||||
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A8Int8MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
weight_quant, input_quant, layer.moe_config
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
@ -650,17 +670,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
"input_activations"
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
|
||||
per_tensor = (
|
||||
self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||
and self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||
@ -698,11 +720,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# cutlass path
|
||||
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
|
||||
self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
|
||||
self.weight_quant, self.input_quant
|
||||
)
|
||||
self.use_cutlass = not self.block_quant and (
|
||||
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
|
||||
CompressedTensorsConfig._is_fp8_w8a8_sm90(
|
||||
self.weight_quant, self.input_quant
|
||||
)
|
||||
or self.is_fp8_w8a8_sm100
|
||||
)
|
||||
self.disable_expert_map = False
|
||||
@ -1261,16 +1285,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
"input_activations"
|
||||
)
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
|
||||
per_channel = (
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
@ -1414,36 +1436,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs | None,
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
config = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.num_bits = config.num_bits
|
||||
self.packed_factor = 32 // config.num_bits
|
||||
self.strategy = config.strategy
|
||||
self.group_size = config.group_size
|
||||
self.actorder = config.actorder
|
||||
self.layer_name = layer_name
|
||||
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
|
||||
assert config.symmetric, "Only symmetric quantization is supported for MoE"
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
assert weight_quant.symmetric, (
|
||||
"Only symmetric quantization is supported for MoE"
|
||||
)
|
||||
# Extract properties from weight_quant
|
||||
self.num_bits = weight_quant.num_bits
|
||||
self.packed_factor = 32 // weight_quant.num_bits
|
||||
self.strategy = weight_quant.strategy
|
||||
self.group_size = weight_quant.group_size
|
||||
self.actorder = weight_quant.actorder
|
||||
|
||||
if not (
|
||||
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
|
||||
and self.num_bits in WNA16_SUPPORTED_BITS
|
||||
):
|
||||
raise ValueError(
|
||||
"For Fused MoE layers, only ",
|
||||
f"{CompressionFormat.pack_quantized.value} ",
|
||||
"is supported for the following bits: ",
|
||||
f"{WNA16_SUPPORTED_BITS}",
|
||||
)
|
||||
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
|
||||
self.use_marlin = True
|
||||
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -1812,35 +1825,26 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs | None,
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
config = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.num_bits = config.num_bits
|
||||
self.packed_factor = 32 // config.num_bits
|
||||
self.strategy = config.strategy
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
# Extract properties from weight_quant
|
||||
self.num_bits = weight_quant.num_bits
|
||||
self.packed_factor = 32 // weight_quant.num_bits
|
||||
self.strategy = weight_quant.strategy
|
||||
# channelwise is not supported by this kernel
|
||||
assert config.strategy == "group"
|
||||
self.group_size = config.group_size
|
||||
assert weight_quant.strategy == "group"
|
||||
self.group_size = weight_quant.group_size
|
||||
# grouped actorder isn't supported by this kernel
|
||||
assert config.actorder != "group"
|
||||
assert config.symmetric, "Only symmetric quantization is supported for MoE"
|
||||
|
||||
if not (
|
||||
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
|
||||
and self.num_bits in WNA16_SUPPORTED_BITS
|
||||
):
|
||||
raise ValueError(
|
||||
"For Fused MoE layers, only ",
|
||||
f"{CompressionFormat.pack_quantized.value} ",
|
||||
"is supported for the following bits: ",
|
||||
f"{WNA16_SUPPORTED_BITS}",
|
||||
)
|
||||
assert weight_quant.actorder != "group"
|
||||
assert weight_quant.symmetric, (
|
||||
"Only symmetric quantization is supported for MoE"
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.has_bias = self.moe.has_bias
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
|
||||
# Validate scheme: weights=W4 (channel or group),
|
||||
# activations=dynamic TOKEN (A8)
|
||||
wq = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
|
||||
|
||||
# Must be dynamic per-token activations
|
||||
if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
|
||||
if (
|
||||
input_quant.strategy != QuantizationStrategy.TOKEN
|
||||
or not input_quant.dynamic
|
||||
):
|
||||
raise ValueError(
|
||||
"W4A8-int MoE needs dynamic per-token activation quantization."
|
||||
)
|
||||
|
||||
# Weight can be channel-wise (group_size=None) or group-wise
|
||||
self.group_size = wq.group_size if (wq.group_size is not None) else -1
|
||||
if wq.num_bits != 4:
|
||||
self.group_size = (
|
||||
weight_quant.group_size if (weight_quant.group_size is not None) else -1
|
||||
)
|
||||
if weight_quant.num_bits != 4:
|
||||
raise ValueError("This method only supports 4-bit weights (num_bits=4).")
|
||||
|
||||
# CPU only
|
||||
|
||||
@ -921,7 +921,17 @@ def gguf_quant_weights_iterator(
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
if weight_type.name not in ("F32", "BF16", "F16"):
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
if weight_type.name == "BF16" and tensor.data.dtype == np.uint8:
|
||||
# BF16 is currently the only "quantization" type that isn't
|
||||
# actually quantized but is read as a raw byte tensor.
|
||||
# Reinterpret as `torch.bfloat16` tensor.
|
||||
weight = weight.view(np.uint16)
|
||||
if reader.byte_order == "S":
|
||||
# GGUF endianness != system endianness
|
||||
weight = weight.byteswap()
|
||||
param = torch.tensor(weight).view(torch.bfloat16)
|
||||
else:
|
||||
param = torch.tensor(weight)
|
||||
yield name, param
|
||||
|
||||
|
||||
|
||||
@ -785,6 +785,7 @@ class HunYuanVLForConditionalGeneration(
|
||||
SupportsQuant,
|
||||
SupportsXDRoPE,
|
||||
):
|
||||
merge_by_field_config = True
|
||||
multimodal_cpu_fields = {"image_grid_thw"}
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
|
||||
@ -21,6 +21,8 @@ class MediaWithBytes(Generic[_T]):
|
||||
|
||||
The wrapper delegates attribute access to the underlying media object,
|
||||
making it behave transparently like the wrapped type (e.g., PIL.Image).
|
||||
|
||||
NOTE: Currently, this wrapper is used only for the image modality.
|
||||
"""
|
||||
|
||||
media: _T
|
||||
|
||||
@ -32,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from .base import MediaWithBytes
|
||||
from .processing import MultiModalHashes
|
||||
|
||||
else:
|
||||
@ -59,7 +60,7 @@ Represents a single audio
|
||||
item, which can be passed to a HuggingFace `AudioProcessor`.
|
||||
"""
|
||||
|
||||
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
|
||||
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
|
||||
"""
|
||||
A `transformers.image_utils.ImageInput` representing a single image
|
||||
item, which can be passed to a HuggingFace `ImageProcessor`.
|
||||
|
||||
@ -134,11 +134,17 @@ class EmbeddingItems(
|
||||
or a list of embedding tensors (one per item).
|
||||
"""
|
||||
|
||||
def _unwrap(
|
||||
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""Extract media from wrapper if present."""
|
||||
return item.media if isinstance(item, MediaWithBytes) else item
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get(self, index: int) -> torch.Tensor:
|
||||
return self.data[index]
|
||||
return self._unwrap(self.data[index])
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
@ -478,7 +484,7 @@ class MultiModalDataParser:
|
||||
return ImageEmbeddingItems(data)
|
||||
|
||||
if (
|
||||
isinstance(data, PILImage.Image)
|
||||
isinstance(data, (PILImage.Image, MediaWithBytes))
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
and data.ndim == 3
|
||||
):
|
||||
|
||||
@ -67,8 +67,9 @@ class MediaConnector:
|
||||
to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video":{"num_frames":40}}'`
|
||||
connection: HTTP connection client to download media contents.
|
||||
allowed_local_media_path: A local directory to load media files
|
||||
from.
|
||||
allowed_local_media_path: A local directory to load media files from.
|
||||
allowed_media_domains: If set, only media URLs that belong to this
|
||||
domain can be used for multi-modal inputs.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -123,16 +124,16 @@ class MediaConnector:
|
||||
"Cannot load local files without `--allowed-local-media-path`."
|
||||
)
|
||||
|
||||
filepath = Path(url2pathname(url_spec.path))
|
||||
filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
|
||||
if allowed_local_media_path not in filepath.resolve().parents:
|
||||
raise ValueError(
|
||||
f"The file path {filepath} must be a subpath "
|
||||
f"of `--allowed-local-media-path` {allowed_local_media_path}."
|
||||
f"of `--allowed-local-media-path {allowed_local_media_path}`."
|
||||
)
|
||||
|
||||
return media_io.load_file(filepath)
|
||||
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
|
||||
if (
|
||||
self.allowed_media_domains
|
||||
and url_spec.hostname not in self.allowed_media_domains
|
||||
@ -489,9 +490,16 @@ def fetch_audio(
|
||||
Args:
|
||||
audio_url: URL of the audio file to fetch.
|
||||
audio_io_kwargs: Additional kwargs passed to handle audio IO.
|
||||
|
||||
Warning:
|
||||
This method has direct access to local files and is only intended
|
||||
to be called by user code. Never call this from the online server!
|
||||
"""
|
||||
media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
|
||||
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
|
||||
media_connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path="/",
|
||||
)
|
||||
return media_connector.fetch_audio(audio_url)
|
||||
|
||||
|
||||
@ -503,9 +511,16 @@ def fetch_image(
|
||||
Args:
|
||||
image_url: URL of the image file to fetch.
|
||||
image_io_kwargs: Additional kwargs passed to handle image IO.
|
||||
|
||||
Warning:
|
||||
This method has direct access to local files and is only intended
|
||||
to be called by user code. Never call this from the online server!
|
||||
"""
|
||||
media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
|
||||
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
|
||||
media_connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path="/",
|
||||
)
|
||||
return media_connector.fetch_image(image_url)
|
||||
|
||||
|
||||
@ -517,7 +532,14 @@ def fetch_video(
|
||||
Args:
|
||||
video_url: URL of the video file to fetch.
|
||||
video_io_kwargs: Additional kwargs passed to handle video IO.
|
||||
|
||||
Warning:
|
||||
This method has direct access to local files and is only intended
|
||||
to be called by user code. Never call this from the online server!
|
||||
"""
|
||||
media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
|
||||
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
|
||||
media_connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path="/",
|
||||
)
|
||||
return media_connector.fetch_video(video_url)
|
||||
|
||||
@ -267,7 +267,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
|
||||
return frames, metadata
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
|
||||
def __init__(
|
||||
self,
|
||||
image_io: ImageMediaIO,
|
||||
|
||||
@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin):
|
||||
|
||||
attention_mask = input_ids.ne(self.pad_id)
|
||||
text_inputs["attention_mask"] = attention_mask
|
||||
text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
|
||||
text_inputs["imgs_pos"] = [self.get_imgs_pos(e) for e in input_ids]
|
||||
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
|
||||
|
||||
return_tensors = kwargs.pop("return_tensors", None)
|
||||
|
||||
@ -7,6 +7,7 @@ from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from vllm import envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||
ECConnectorMetadata,
|
||||
@ -1037,6 +1038,7 @@ class Scheduler(SchedulerInterface):
|
||||
pooler_outputs = model_runner_output.pooler_output
|
||||
num_nans_in_logits = model_runner_output.num_nans_in_logits
|
||||
kv_connector_output = model_runner_output.kv_connector_output
|
||||
cudagraph_stats = model_runner_output.cudagraph_stats
|
||||
|
||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||
spec_decoding_stats: SpecDecodingStats | None = None
|
||||
@ -1219,7 +1221,9 @@ class Scheduler(SchedulerInterface):
|
||||
finished_req_ids.clear()
|
||||
|
||||
if (
|
||||
stats := self.make_stats(spec_decoding_stats, kv_connector_stats)
|
||||
stats := self.make_stats(
|
||||
spec_decoding_stats, kv_connector_stats, cudagraph_stats
|
||||
)
|
||||
) is not None:
|
||||
# Return stats to only one of the front-ends.
|
||||
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
|
||||
@ -1420,6 +1424,7 @@ class Scheduler(SchedulerInterface):
|
||||
self,
|
||||
spec_decoding_stats: SpecDecodingStats | None = None,
|
||||
kv_connector_stats: KVConnectorStats | None = None,
|
||||
cudagraph_stats: CUDAGraphStat | None = None,
|
||||
) -> SchedulerStats | None:
|
||||
if not self.log_stats:
|
||||
return None
|
||||
@ -1444,6 +1449,7 @@ class Scheduler(SchedulerInterface):
|
||||
kv_cache_eviction_events=eviction_events,
|
||||
spec_decoding_stats=spec_stats,
|
||||
kv_connector_stats=connector_stats_payload,
|
||||
cudagraph_stats=cudagraph_stats,
|
||||
)
|
||||
|
||||
def make_spec_decoding_stats(
|
||||
|
||||
@ -10,6 +10,7 @@ from typing import TypeAlias
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphLogging
|
||||
from vllm.config import SupportsMetricsInfo, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorLogging,
|
||||
@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.spec_decoding_logging = SpecDecodingLogging()
|
||||
kv_transfer_config = self.vllm_config.kv_transfer_config
|
||||
self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
|
||||
self.cudagraph_logging = None
|
||||
if self.vllm_config.observability_config.cudagraph_metrics:
|
||||
self.cudagraph_logging = CUDAGraphLogging(
|
||||
self.vllm_config.compilation_config.cudagraph_mode,
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes,
|
||||
)
|
||||
self.last_prompt_throughput: float = 0.0
|
||||
self.last_generation_throughput: float = 0.0
|
||||
self.engine_is_idle = False
|
||||
@ -161,6 +168,11 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
|
||||
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
||||
self.kv_connector_logging.observe(kv_connector_stats)
|
||||
if (
|
||||
self.cudagraph_logging is not None
|
||||
and scheduler_stats.cudagraph_stats is not None
|
||||
):
|
||||
self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
|
||||
if not self.aggregated:
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
if mm_cache_stats:
|
||||
@ -240,6 +252,8 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
self.kv_connector_logging.log(log_fn=log_fn)
|
||||
if self.cudagraph_logging is not None:
|
||||
self.cudagraph_logging.log(log_fn=log_fn)
|
||||
|
||||
def log_engine_initialized(self):
|
||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||
|
||||
@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -183,6 +184,8 @@ class SchedulerStats:
|
||||
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
|
||||
running_lora_adapters: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
cudagraph_stats: CUDAGraphStat | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestStateStats:
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -169,6 +170,9 @@ class ModelRunnerOutput:
|
||||
# req_id -> num_nans_in_logits
|
||||
num_nans_in_logits: dict[str, int] | None = None
|
||||
|
||||
# information related to cudagraph execution
|
||||
cudagraph_stats: CUDAGraphStat | None = None
|
||||
|
||||
|
||||
# ModelRunnerOutput wrapper for async scheduling.
|
||||
class AsyncModelRunnerOutput(ABC):
|
||||
|
||||
@ -27,7 +27,7 @@ from vllm.attention.backends.abstract import (
|
||||
)
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple):
|
||||
sample_hidden_states: torch.Tensor
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
ec_connector_output: ECConnectorOutput | None
|
||||
cudagraph_stats: CUDAGraphStat | None
|
||||
|
||||
|
||||
class GPUModelRunner(
|
||||
@ -2417,10 +2418,7 @@ class GPUModelRunner(
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if (
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism
|
||||
and tp_size > 1
|
||||
):
|
||||
if self.compilation_config.pass_config.enable_sp and tp_size > 1:
|
||||
return round_up(num_scheduled_tokens, tp_size)
|
||||
return num_scheduled_tokens
|
||||
|
||||
@ -2758,7 +2756,11 @@ class GPUModelRunner(
|
||||
force_uniform_decode: bool | None = None,
|
||||
force_has_lora: bool | None = None,
|
||||
) -> tuple[
|
||||
CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
|
||||
CUDAGraphMode,
|
||||
BatchDescriptor,
|
||||
UBatchSlices | None,
|
||||
torch.Tensor | None,
|
||||
CUDAGraphStat | None,
|
||||
]:
|
||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||
uniform_decode = (
|
||||
@ -2823,7 +2825,22 @@ class GPUModelRunner(
|
||||
# num_tokens_across_dp will no-longer be valid
|
||||
assert batch_descriptor.num_tokens == num_tokens_padded
|
||||
|
||||
return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
|
||||
cudagraph_stats = None
|
||||
if self.vllm_config.observability_config.cudagraph_metrics:
|
||||
cudagraph_stats = CUDAGraphStat(
|
||||
num_unpadded_tokens=num_tokens,
|
||||
num_padded_tokens=batch_descriptor.num_tokens,
|
||||
num_paddings=batch_descriptor.num_tokens - num_tokens,
|
||||
runtime_mode=str(cudagraph_mode),
|
||||
)
|
||||
|
||||
return (
|
||||
cudagraph_mode,
|
||||
batch_descriptor,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
cudagraph_stats,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
@ -2921,6 +2938,7 @@ class GPUModelRunner(
|
||||
batch_desc,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
cudagraph_stats,
|
||||
) = self._determine_batch_execution_and_padding(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
num_reqs=num_reqs,
|
||||
@ -3070,6 +3088,7 @@ class GPUModelRunner(
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
ec_connector_output,
|
||||
cudagraph_stats,
|
||||
)
|
||||
self.kv_connector_output = kv_connector_output
|
||||
return None
|
||||
@ -3105,6 +3124,7 @@ class GPUModelRunner(
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
ec_connector_output,
|
||||
cudagraph_stats,
|
||||
) = self.execute_model_state
|
||||
# Clear ephemeral state.
|
||||
self.execute_model_state = None
|
||||
@ -3220,6 +3240,7 @@ class GPUModelRunner(
|
||||
if self.supports_mm_inputs
|
||||
else None,
|
||||
num_nans_in_logits=num_nans_in_logits,
|
||||
cudagraph_stats=cudagraph_stats,
|
||||
)
|
||||
|
||||
if not self.use_async_scheduling:
|
||||
@ -3940,7 +3961,7 @@ class GPUModelRunner(
|
||||
|
||||
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
||||
|
||||
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
|
||||
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
num_reqs=num_reqs,
|
||||
|
||||
@ -552,7 +552,7 @@ class Worker(WorkerBase):
|
||||
|
||||
if (
|
||||
parallel_config.pipeline_parallel_size > 1
|
||||
and compilation_config.pass_config.enable_sequence_parallelism
|
||||
and compilation_config.pass_config.enable_sp
|
||||
and forward_pass
|
||||
):
|
||||
# currently only supported by V1 GPUModelRunner
|
||||
@ -564,7 +564,7 @@ class Worker(WorkerBase):
|
||||
# TODO(lucas): This is pretty gross; ideally we should only ever call
|
||||
# `_determine_batch_execution_and_padding` once (will get called again
|
||||
# in `execute_model`) but this requires a larger refactor of PP.
|
||||
_, batch_desc, _, _ = (
|
||||
_, batch_desc, _, _, _ = (
|
||||
self.model_runner._determine_batch_execution_and_padding(
|
||||
num_tokens=num_scheduled_tokens,
|
||||
num_reqs=len(num_scheduled_tokens_np),
|
||||
|
||||
@ -342,7 +342,7 @@ def is_residual_scattered_for_sp(
|
||||
partition), SP is always applied
|
||||
- Otherwise, SP is only applied for specific shapes in compile_sizes
|
||||
"""
|
||||
if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
if not vllm_config.compilation_config.pass_config.enable_sp:
|
||||
return False
|
||||
|
||||
tp = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user