mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 04:39:09 +08:00
Merge branch 'main' into model_arch_cfg
Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com>
This commit is contained in:
commit
f7344c971c
@ -291,6 +291,7 @@ if __name__ == "__main__":
|
||||
"""
|
||||
Arguments:
|
||||
--version <version> : version string for the current build (e.g., commit hash)
|
||||
--wheel-dir <wheel_directory> : directory containing wheel files (default to be same as `version`)
|
||||
--current-objects <path_to_json> : path to JSON file containing current S3 objects listing in this version directory
|
||||
--output-dir <output_directory> : directory to store generated index files
|
||||
--alias-to-default <alias_variant_name> : (optional) alias variant name for the default variant
|
||||
@ -318,6 +319,12 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
help="Directory to store generated index files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wheel-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory containing wheel files (default to be same as `version`)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alias-to-default",
|
||||
type=str,
|
||||
@ -372,7 +379,7 @@ if __name__ == "__main__":
|
||||
|
||||
print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
|
||||
|
||||
# keep only "official" files for a non-nightly version (specifed by cli args)
|
||||
# keep only "official" files for a non-nightly version (specified by cli args)
|
||||
PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")
|
||||
if PY_VERSION_RE.match(version):
|
||||
# upload-wheels.sh ensures no "dev" is in args.version
|
||||
@ -384,9 +391,10 @@ if __name__ == "__main__":
|
||||
print("Nightly version detected, keeping all wheel files.")
|
||||
|
||||
# Generate index and metadata, assuming wheels and indices are stored as:
|
||||
# s3://vllm-wheels/{version}/<wheel files>
|
||||
# s3://vllm-wheels/{wheel_dir}/<wheel files>
|
||||
# s3://vllm-wheels/<anything>/<index files>
|
||||
wheel_base_dir = Path(output_dir).parent / version
|
||||
wheel_dir = args.wheel_dir or version
|
||||
wheel_base_dir = Path(output_dir).parent / wheel_dir.strip().rstrip("/")
|
||||
index_base_dir = Path(output_dir)
|
||||
|
||||
generate_index_and_metadata(
|
||||
|
||||
@ -102,6 +102,7 @@ if [[ "$version" != *"dev"* ]]; then
|
||||
echo "Re-generating indices for /$pure_version/"
|
||||
rm -rf "$INDICES_OUTPUT_DIR/*"
|
||||
mkdir -p "$INDICES_OUTPUT_DIR"
|
||||
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
|
||||
# wheel-dir is overridden to be the commit directory, so that the indices point to the correct wheel path
|
||||
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --wheel-dir "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
|
||||
fi
|
||||
|
||||
@ -349,7 +349,9 @@ steps:
|
||||
- label: V1 Test e2e + engine # 65min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_4
|
||||
# The test uses 4 GPUs, but we schedule it on 8-GPU machines for stability.
|
||||
# See discussion here: https://github.com/vllm-project/vllm/pull/31040
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -1254,13 +1256,13 @@ steps:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --nnodes=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --nnodes=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 68min
|
||||
timeout_in_minutes: 90
|
||||
|
||||
@ -1109,13 +1109,13 @@ steps:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --nnodes=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --nnodes=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 68min
|
||||
timeout_in_minutes: 90
|
||||
|
||||
@ -171,7 +171,7 @@ steps:
|
||||
- tests/distributed/
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
commands:
|
||||
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:0bec63fa317e1fbd62e19b0fc31c43c81bf89077 "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --nnodes=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --nnodes=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code"
|
||||
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:0bec63fa317e1fbd62e19b0fc31c43c81bf89077 "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"
|
||||
|
||||
- label: Distributed NixlConnector PD accuracy (4 GPUs)
|
||||
timeout_in_minutes: 30
|
||||
|
||||
@ -104,7 +104,6 @@ def run_benchmark_with_batch_invariant(
|
||||
random.seed(seed)
|
||||
|
||||
# Set environment variables
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
if batch_invariant:
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
else:
|
||||
@ -140,6 +139,7 @@ def run_benchmark_with_batch_invariant(
|
||||
max_model_len=max_model_len,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=tp_size,
|
||||
attention_config={"backend": backend},
|
||||
enable_prefix_caching=False,
|
||||
)
|
||||
init_time = time.perf_counter() - start_init
|
||||
|
||||
@ -107,7 +107,8 @@ __global__ void fusedQKNormRopeKernel(
|
||||
void const* k_weight_void, // RMSNorm weights for key
|
||||
void const* cos_sin_cache_void, // Pre-computed cos/sin cache
|
||||
int64_t const* position_ids, // Position IDs for RoPE
|
||||
int const num_tokens // Number of tokens
|
||||
int const num_tokens, // Number of tokens
|
||||
int const rotary_dim // Dimension for RoPE
|
||||
) {
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
|
||||
@ -227,56 +228,59 @@ __global__ void fusedQKNormRopeKernel(
|
||||
|
||||
// Calculate cache pointer for this position - similar to
|
||||
// pos_encoding_kernels.cu
|
||||
T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
|
||||
int const embed_dim = head_dim / 2;
|
||||
T_cache const* cache_ptr = cos_sin_cache + pos_id * rotary_dim;
|
||||
int const embed_dim = rotary_dim / 2;
|
||||
T_cache const* cos_ptr = cache_ptr;
|
||||
T_cache const* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
if constexpr (interleave) {
|
||||
// Perform interleaving. Use pre-computed cos/sin values.
|
||||
int const rotary_lanes = rotary_dim / numElemsPerThread; // rotary range
|
||||
if (laneId < rotary_lanes) {
|
||||
if constexpr (interleave) {
|
||||
// Perform interleaving. Use pre-computed cos/sin values.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numElemsPerThread / 2; ++i) {
|
||||
int const idx0 = 2 * i;
|
||||
int const idx1 = 2 * i + 1;
|
||||
for (int i = 0; i < numElemsPerThread / 2; ++i) {
|
||||
int const idx0 = 2 * i;
|
||||
int const idx1 = 2 * i + 1;
|
||||
// Global dimension index in the head
|
||||
int const dim_idx = laneId * numElemsPerThread + idx0;
|
||||
|
||||
float const val0 = elements[idx0];
|
||||
float const val1 = elements[idx1];
|
||||
float const val0 = elements[idx0];
|
||||
float const val1 = elements[idx1];
|
||||
|
||||
int const dim_idx = laneId * numElemsPerThread + idx0;
|
||||
int const half_dim = dim_idx / 2;
|
||||
float const cos_val =
|
||||
CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
|
||||
float const sin_val =
|
||||
CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
|
||||
int const half_dim = dim_idx / 2;
|
||||
float const cos_val =
|
||||
CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
|
||||
float const sin_val =
|
||||
CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
|
||||
|
||||
elements[idx0] = val0 * cos_val - val1 * sin_val;
|
||||
elements[idx1] = val0 * sin_val + val1 * cos_val;
|
||||
}
|
||||
} else {
|
||||
// Before data exchange with in warp, we need to sync.
|
||||
__syncwarp();
|
||||
// Get the data from the other half of the warp. Use pre-computed cos/sin
|
||||
// values.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numElemsPerThread; i++) {
|
||||
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
|
||||
if (laneId < 16) {
|
||||
elements2[i] = -elements2[i];
|
||||
elements[idx0] = val0 * cos_val - val1 * sin_val;
|
||||
elements[idx1] = val0 * sin_val + val1 * cos_val;
|
||||
}
|
||||
} else {
|
||||
// Before data exchange with in warp, we need to sync.
|
||||
__syncwarp();
|
||||
int pairOffset = (rotary_dim / 2) / numElemsPerThread;
|
||||
// Get the data from the other half of the warp. Use pre-computed
|
||||
// cos/sin values.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numElemsPerThread; i++) {
|
||||
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], pairOffset);
|
||||
|
||||
int dim_idx = laneId * numElemsPerThread + i;
|
||||
dim_idx = (dim_idx * 2) % head_dim;
|
||||
int half_dim = dim_idx / 2;
|
||||
// Use pre-computed cos/sin from cache
|
||||
float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
|
||||
float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
|
||||
if (laneId < pairOffset) {
|
||||
elements2[i] = -elements2[i];
|
||||
}
|
||||
int dim_idx = laneId * numElemsPerThread + i;
|
||||
|
||||
elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
|
||||
dim_idx = (dim_idx * 2) % rotary_dim;
|
||||
int half_dim = dim_idx / 2;
|
||||
float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
|
||||
float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
|
||||
|
||||
elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
|
||||
}
|
||||
// __shfl_xor_sync does not provide memfence. Need to sync again.
|
||||
__syncwarp();
|
||||
}
|
||||
// __shfl_xor_sync does not provide memfence. Need to sync again.
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Store.
|
||||
{
|
||||
vec_T vec;
|
||||
@ -312,10 +316,10 @@ template <typename scalar_t_in, typename scalar_t_cache>
|
||||
void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
||||
int const num_heads_q, int const num_heads_k,
|
||||
int const num_heads_v, int const head_dim,
|
||||
float const eps, void const* q_weight,
|
||||
void const* k_weight, void const* cos_sin_cache,
|
||||
bool const interleave, int64_t const* position_ids,
|
||||
cudaStream_t stream) {
|
||||
int const rotary_dim, float const eps,
|
||||
void const* q_weight, void const* k_weight,
|
||||
void const* cos_sin_cache, bool const interleave,
|
||||
int64_t const* position_ids, cudaStream_t stream) {
|
||||
constexpr int blockSize = 256;
|
||||
|
||||
int const warpsPerBlock = blockSize / 32;
|
||||
@ -332,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
||||
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
|
||||
<<<gridDim, blockDim, 0, stream>>>(
|
||||
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens);
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
|
||||
});
|
||||
break;
|
||||
case 128:
|
||||
@ -340,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
||||
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
|
||||
<<<gridDim, blockDim, 0, stream>>>(
|
||||
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens);
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
|
||||
});
|
||||
break;
|
||||
case 256:
|
||||
@ -348,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
||||
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
|
||||
<<<gridDim, blockDim, 0, stream>>>(
|
||||
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens);
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
|
||||
});
|
||||
break;
|
||||
default:
|
||||
@ -392,8 +396,11 @@ void fused_qk_norm_rope(
|
||||
"Query weights size must match head dimension");
|
||||
TORCH_CHECK(k_weight.size(0) == head_dim,
|
||||
"Key weights size must match head dimension");
|
||||
TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
|
||||
"Cos/sin cache dimension must match head_dim");
|
||||
|
||||
TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even");
|
||||
TORCH_CHECK(cos_sin_cache.size(1) <= head_dim,
|
||||
"rotary_dim must be less than or equal to head_dim");
|
||||
|
||||
TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
|
||||
qkv.scalar_type() == k_weight.scalar_type(),
|
||||
"qkv, q_weight and k_weight must have the same dtype");
|
||||
@ -419,7 +426,8 @@ void fused_qk_norm_rope(
|
||||
qkv.data_ptr(), static_cast<int>(num_tokens),
|
||||
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
|
||||
static_cast<int>(num_heads_v), static_cast<int>(head_dim),
|
||||
static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
|
||||
static_cast<int>(cos_sin_cache.size(1)), static_cast<float>(eps),
|
||||
q_weight.data_ptr(), k_weight.data_ptr(),
|
||||
cos_sin_cache.data_ptr(), !is_neox,
|
||||
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
|
||||
stream);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
|
||||
ARG TRITON_BRANCH="a272dfa8"
|
||||
ARG TRITON_BRANCH="57c693b6"
|
||||
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
||||
ARG PYTORCH_BRANCH="89075173"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
@ -162,4 +162,4 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
|
||||
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
|
||||
@ -2,7 +2,7 @@ FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base
|
||||
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
|
||||
add-apt-repository -y ppa:kobuk-team/intel-graphics
|
||||
add-apt-repository -y ppa:kobuk-team/intel-graphics-staging
|
||||
|
||||
RUN apt clean && apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends --fix-missing \
|
||||
@ -47,6 +47,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --no-cache-dir \
|
||||
-r requirements/xpu.txt
|
||||
|
||||
# arctic-inference is built from source which needs torch-xpu properly installed
|
||||
# used for suffix method speculative decoding
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --no-cache-dir arctic-inference==0.1.1
|
||||
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"
|
||||
|
||||
COPY . .
|
||||
|
||||
@ -139,18 +139,18 @@ token data.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
```
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="query" width="70%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/query.png" alt="query" width="70%" />
|
||||
</p>
|
||||
|
||||
Each thread defines its own `q_ptr` which points to the assigned
|
||||
query token data on global memory. For example, if `VEC_SIZE` is 4
|
||||
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains
|
||||
total of 128 elements divided into 128 / 4 = 32 vecs.
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="q_vecs" width="70%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/q_vecs.png" alt="q_vecs" width="70%" />
|
||||
</p>
|
||||
|
||||
```cpp
|
||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||
@ -187,9 +187,9 @@ key token at different iterations. As shown above, that `k_ptr`
|
||||
points to key token data based on `k_cache` at assigned block,
|
||||
assigned head and assigned token.
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="key" width="70%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/key.png" alt="key" width="70%" />
|
||||
</p>
|
||||
|
||||
The diagram above illustrates the memory layout for key data. It
|
||||
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
|
||||
@ -202,9 +202,9 @@ iterations. Inside each rectangle, there are a total 32 vecs (128
|
||||
elements for one token) that will be processed by 2 threads (one
|
||||
thread group) separately.
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="k_vecs" width="70%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/k_vecs.png" alt="k_vecs" width="70%" />
|
||||
</p>
|
||||
|
||||
```cpp
|
||||
K_vec k_vecs[NUM_VECS_PER_THREAD]
|
||||
@ -361,17 +361,17 @@ later steps. Now, it should store the normalized softmax result of
|
||||
|
||||
## Value
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="value" width="70%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/value.png" alt="value" width="70%" />
|
||||
</p>
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="logits_vec" width="50%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/logits_vec.png" alt="logits_vec" width="50%" />
|
||||
</p>
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="v_vec" width="70%" }
|
||||
</figure>
|
||||
<p align="center">
|
||||
<img src="../assets/design/paged_attention/v_vec.png" alt="v_vec" width="70%" />
|
||||
</p>
|
||||
|
||||
Now we need to retrieve the value data and perform dot multiplication
|
||||
with `logits`. Unlike query and key, there is no thread group
|
||||
|
||||
@ -64,7 +64,7 @@ th:not(:first-child) {
|
||||
| [CP](../configuration/optimization.md#chunked-prefill) | [❌](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [APC](automatic_prefix_caching.md) | [❌](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26963) |
|
||||
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |
|
||||
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/26970) |
|
||||
| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||
|
||||
@ -8,6 +8,16 @@ We recommend installing the library with:
|
||||
pip install nvidia-modelopt
|
||||
```
|
||||
|
||||
## Supported ModelOpt checkpoint formats
|
||||
|
||||
vLLM detects ModelOpt checkpoints via `hf_quant_config.json` and supports the
|
||||
following `quantization.quant_algo` values:
|
||||
|
||||
- `FP8`: per-tensor weight scale (+ optional static activation scale).
|
||||
- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization.
|
||||
- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks).
|
||||
- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`).
|
||||
|
||||
## Quantizing HuggingFace Models with PTQ
|
||||
|
||||
You can quantize HuggingFace models using the example scripts provided in the Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory.
|
||||
@ -80,3 +90,24 @@ The quantized checkpoint can then be deployed with vLLM. As an example, the foll
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
## Running the OpenAI-compatible server
|
||||
|
||||
To serve a local ModelOpt checkpoint via the OpenAI-compatible API:
|
||||
|
||||
```bash
|
||||
vllm serve <path_to_exported_checkpoint> \
|
||||
--quantization modelopt \
|
||||
--host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## Testing (local checkpoints)
|
||||
|
||||
vLLM's ModelOpt unit tests are gated by local checkpoint paths and are skipped
|
||||
by default in CI. To run the tests locally:
|
||||
|
||||
```bash
|
||||
export VLLM_TEST_MODELOPT_FP8_PC_PT_MODEL_PATH=<path_to_fp8_pc_pt_checkpoint>
|
||||
export VLLM_TEST_MODELOPT_FP8_PB_WO_MODEL_PATH=<path_to_fp8_pb_wo_checkpoint>
|
||||
pytest -q tests/quantization/test_modelopt.py
|
||||
```
|
||||
|
||||
@ -17,6 +17,16 @@ The E4M3 format offers higher precision compared to E5M2. However, due to its sm
|
||||
|
||||
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
|
||||
|
||||
### How FP8 KV Cache Works
|
||||
|
||||
The FP8 KV cache implementation follows this workflow:
|
||||
|
||||
1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache
|
||||
2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16)
|
||||
3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor
|
||||
|
||||
This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations.
|
||||
|
||||
### Performance Impact
|
||||
|
||||
The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
|
||||
|
||||
@ -28,3 +28,4 @@ The backends below live **outside** the main `vllm` repository and follow the
|
||||
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
||||
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
||||
| Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> |
|
||||
| Apple Silicon (Metal) | N/A, install from source | <https://github.com/vllm-project/vllm-metal> |
|
||||
|
||||
@ -4,6 +4,9 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must
|
||||
|
||||
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
|
||||
|
||||
!!! tip "GPU-Accelerated Inference with vLLM-Metal"
|
||||
For GPU-accelerated inference on Apple Silicon using Metal, check out [vllm-metal](https://github.com/vllm-project/vllm-metal), a community-maintained hardware plugin that uses MLX as the compute backend.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
|
||||
@ -418,7 +418,7 @@ th {
|
||||
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | ︎| ✅︎ |
|
||||
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
|
||||
| `MiniMaxM2ForCausalLM` | MiniMax-M2, MiniMax-M2.1 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
|
||||
| `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `MistralLarge3ForCausalLM` | Mistral-Large-3-675B-Base-2512, Mistral-Large-3-675B-Instruct-2512 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ |
|
||||
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@ -16,7 +16,7 @@ To run inference on a single or multiple GPUs, use `VLLM` class from `langchain`
|
||||
from langchain_community.llms import VLLM
|
||||
|
||||
llm = VLLM(
|
||||
model="mosaicml/mpt-7b",
|
||||
model="Qwen/Qwen3-4B",
|
||||
trust_remote_code=True, # mandatory for hf models
|
||||
max_new_tokens=128,
|
||||
top_k=10,
|
||||
|
||||
@ -14,19 +14,19 @@ Multi-node:
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
-dp=2 \
|
||||
-tp=2 \
|
||||
--nnodes=2 \
|
||||
--node-rank=0 \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
--dp-num-nodes=2 \
|
||||
--dp-node-rank=0 \
|
||||
--dp-master-addr=10.99.48.128 \
|
||||
--dp-master-port=13345
|
||||
Node 1:
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
-dp=2 \
|
||||
-tp=2 \
|
||||
--nnodes=2 \
|
||||
--node-rank=1 \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
--dp-num-nodes=2 \
|
||||
--dp-node-rank=1 \
|
||||
--dp-master-addr=10.99.48.128 \
|
||||
--dp-master-port=13345
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -48,7 +48,31 @@ def create_parser():
|
||||
enable_expert_parallel=True,
|
||||
)
|
||||
|
||||
# Add timeout (not in EngineArgs)
|
||||
# Add DP-specific args (separate from engine args to avoid conflicts)
|
||||
parser.add_argument(
|
||||
"--dp-num-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total number of nodes for data parallel.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dp-node-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of the current node for data parallel.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dp-master-addr",
|
||||
type=str,
|
||||
default="",
|
||||
help="Master node IP address for DP coordination.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dp-master-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Master node port for DP coordination.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
@ -132,26 +156,26 @@ if __name__ == "__main__":
|
||||
parser = create_parser()
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
# Extract DP-specific args
|
||||
# Extract DP-specific args (pop to remove from engine_args)
|
||||
dp_size = args.pop("data_parallel_size")
|
||||
nnodes = args.get("nnodes", 1)
|
||||
node_rank = args.get("node_rank", 0)
|
||||
master_addr = args.get("master_addr", "")
|
||||
master_port = args.get("master_port", 0)
|
||||
dp_num_nodes = args.pop("dp_num_nodes")
|
||||
dp_node_rank = args.pop("dp_node_rank")
|
||||
dp_master_addr = args.pop("dp_master_addr")
|
||||
dp_master_port = args.pop("dp_master_port")
|
||||
timeout = args.pop("timeout")
|
||||
|
||||
# Remaining args are engine args
|
||||
engine_args = args
|
||||
|
||||
if nnodes == 1:
|
||||
if dp_num_nodes == 1:
|
||||
dp_master_ip = "127.0.0.1"
|
||||
dp_master_port = get_open_port()
|
||||
dp_master_port_val = get_open_port()
|
||||
else:
|
||||
dp_master_ip = master_addr
|
||||
dp_master_port = master_port
|
||||
dp_master_ip = dp_master_addr
|
||||
dp_master_port_val = dp_master_port
|
||||
|
||||
assert dp_size % nnodes == 0, "dp_size should be divisible by nnodes"
|
||||
dp_per_node = dp_size // nnodes
|
||||
assert dp_size % dp_num_nodes == 0, "dp_size should be divisible by dp_num_nodes"
|
||||
dp_per_node = dp_size // dp_num_nodes
|
||||
|
||||
from multiprocessing import Process
|
||||
|
||||
@ -162,7 +186,7 @@ if __name__ == "__main__":
|
||||
|
||||
procs = []
|
||||
for local_dp_rank, global_dp_rank in enumerate(
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
|
||||
range(dp_node_rank * dp_per_node, (dp_node_rank + 1) * dp_per_node)
|
||||
):
|
||||
proc = Process(
|
||||
target=main,
|
||||
@ -171,7 +195,7 @@ if __name__ == "__main__":
|
||||
local_dp_rank,
|
||||
global_dp_rank,
|
||||
dp_master_ip,
|
||||
dp_master_port,
|
||||
dp_master_port_val,
|
||||
engine_args,
|
||||
),
|
||||
)
|
||||
|
||||
@ -38,6 +38,8 @@ Encoder engines should be launched with the following flags:
|
||||
|
||||
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
|
||||
|
||||
- `--convert "mm_encoder_only"` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must implement the `get_language_model_spec` interface.**
|
||||
|
||||
## Local media inputs
|
||||
|
||||
To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance:
|
||||
|
||||
@ -557,7 +557,8 @@ def test_rms_group_quant(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
|
||||
@ -77,6 +77,7 @@ def test_dynamic_shapes_compilation(
|
||||
"evaluate_guards": evaluate_guards,
|
||||
},
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
|
||||
output = model.generate(prompt)
|
||||
|
||||
@ -313,7 +313,7 @@ async def test_chat_streaming_input_audio(
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
{"type": "text", "text": "What's a short title for this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
@ -13,6 +13,7 @@ DTYPES = [torch.bfloat16, torch.float16]
|
||||
IS_NEOX = [True, False]
|
||||
EPS_VALUES = [1e-5, 1e-6]
|
||||
SEEDS = [13]
|
||||
PARTIAL_ROPE = [True, False]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
@ -52,6 +53,7 @@ def _apply_qk_norm_rope(
|
||||
@pytest.mark.parametrize("is_neox", IS_NEOX)
|
||||
@pytest.mark.parametrize("eps", EPS_VALUES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
|
||||
@torch.inference_mode()
|
||||
def test_fused_qk_norm_rope_matches_reference(
|
||||
device: str,
|
||||
@ -59,6 +61,7 @@ def test_fused_qk_norm_rope_matches_reference(
|
||||
is_neox: bool,
|
||||
eps: float,
|
||||
seed: int,
|
||||
rotary_ratio: float,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
@ -76,10 +79,10 @@ def test_fused_qk_norm_rope_matches_reference(
|
||||
k_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
q_weight = q_norm.weight.data
|
||||
k_weight = k_norm.weight.data
|
||||
|
||||
rotary_dim = int(head_dim * rotary_ratio)
|
||||
rope = RotaryEmbedding(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000.0,
|
||||
is_neox_style=is_neox,
|
||||
|
||||
@ -258,16 +258,16 @@ class Config:
|
||||
f"{self.fe_supported_types()}."
|
||||
)
|
||||
|
||||
# Check block quanization support
|
||||
is_block_quatized = self.quant_block_shape is not None
|
||||
if is_block_quatized and self.quant_dtype is None:
|
||||
# Check block quantization support
|
||||
is_block_quantized = self.quant_block_shape is not None
|
||||
if is_block_quantized and self.quant_dtype is None:
|
||||
return False, "No block quantization support."
|
||||
|
||||
if is_block_quatized and not self.is_block_quant_supported():
|
||||
if is_block_quantized and not self.is_block_quant_supported():
|
||||
return False, "Mismatched block quantization support."
|
||||
|
||||
# deep_gemm only works with block-quantized
|
||||
if self.needs_deep_gemm() and not is_block_quatized:
|
||||
if self.needs_deep_gemm() and not is_block_quantized:
|
||||
return False, "Needs DeepGEMM but not block quantized."
|
||||
|
||||
# Check dependencies (turn into asserts?)
|
||||
|
||||
@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
NUM_EXPERTS = [8, 64, 192]
|
||||
EP_SIZE = [1, 4]
|
||||
@ -487,6 +488,7 @@ def test_mixtral_moe(
|
||||
monkeypatch.setenv("MASTER_ADDR", "localhost")
|
||||
monkeypatch.setenv("MASTER_PORT", "12345")
|
||||
init_distributed_environment()
|
||||
init_workspace_manager(torch.cuda.current_device())
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
vllm_config.compilation_config.static_forward_context = dict()
|
||||
@ -533,6 +535,11 @@ def test_mixtral_moe(
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# FIXME (zyongye) fix this after we move self.kernel
|
||||
# assignment in FusedMoE.__init__
|
||||
|
||||
vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||
vllm_states = vllm_moe.forward(vllm_inputs)
|
||||
|
||||
@ -138,7 +138,7 @@ def create_batched_mm_kwargs(
|
||||
)
|
||||
|
||||
|
||||
# TODO(Isotr0py): Don't initalize model during test
|
||||
# TODO(Isotr0py): Don't initialize model during test
|
||||
@contextmanager
|
||||
def initialize_dummy_model(
|
||||
model_cls: type[nn.Module],
|
||||
|
||||
@ -215,7 +215,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
|
||||
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
|
||||
# FIXME: databricks/dbrx-instruct has been deleted
|
||||
"DbrxForCausalLM": _HfExamplesInfo(
|
||||
"databricks/dbrx-instruct", is_available_online=False
|
||||
),
|
||||
"DeciLMForCausalLM": _HfExamplesInfo(
|
||||
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
|
||||
trust_remote_code=True,
|
||||
@ -366,7 +369,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
{"tiny": "TitanML/tiny-mixtral"},
|
||||
),
|
||||
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
|
||||
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
|
||||
# FIXME: mosaicml/mpt-7b has been deleted
|
||||
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b", is_available_online=False),
|
||||
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
|
||||
"NemotronHForCausalLM": _HfExamplesInfo(
|
||||
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
|
||||
|
||||
@ -83,7 +83,7 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
|
||||
current_platform.is_rocm()
|
||||
and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
|
||||
):
|
||||
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
|
||||
pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
|
||||
|
||||
with vllm_runner(model_path, enforce_eager=True) as llm:
|
||||
|
||||
@ -161,7 +161,7 @@ def test_compressed_tensors_w8a8_logprobs(
|
||||
current_platform.is_rocm()
|
||||
and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
|
||||
):
|
||||
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
|
||||
pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
|
||||
|
||||
if use_aiter:
|
||||
if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
|
||||
@ -231,7 +231,7 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
|
||||
current_platform.is_rocm()
|
||||
and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
|
||||
):
|
||||
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
|
||||
pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
|
||||
|
||||
if use_aiter:
|
||||
if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
|
||||
|
||||
@ -217,7 +217,7 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
|
||||
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
|
||||
|
||||
# Reference dynamic quantizaton
|
||||
# Reference dynamic quantization
|
||||
y = quantize_ref(x, inv_scale)
|
||||
torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ Run `pytest tests/quantization/test_modelopt.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import NoReturn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -19,6 +20,28 @@ def enable_pickle(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
def _skip(msg: str) -> NoReturn:
    """Skip the current test with ``msg``; never returns to the caller.

    ``pytest.skip`` is invoked first. The trailing ``raise`` would only be
    reached if that call ever returned, so the ``NoReturn`` annotation holds
    on every path through this helper.
    """
    pytest.skip(msg)
    raise RuntimeError(msg)
|
||||
|
||||
|
||||
def _snapshot_download_or_skip(model_id: str) -> str:
    """Fetch ``model_id`` with ``huggingface_hub.snapshot_download`` or skip.

    Args:
        model_id: Repository id, forwarded as ``repo_id`` with
            ``repo_type="model"``.

    Returns:
        Whatever ``snapshot_download`` returns for the repository.

    Any exception raised while importing ``huggingface_hub`` or while
    downloading is converted into a test skip via ``_skip`` instead of a
    test failure.
    """
    try:
        from huggingface_hub import snapshot_download
    except Exception as import_err:  # pragma: no cover
        _skip(f"huggingface_hub is required to download {model_id}: {import_err}")

    download_kwargs = {
        "repo_id": model_id,
        "repo_type": "model",
        # These checkpoints are already small; download full repo for simplicity.
        "allow_patterns": ["*"],
    }
    try:
        local_dir = snapshot_download(**download_kwargs)
    except Exception as download_err:
        _skip(f"Failed to download {model_id} from the HF Hub: {download_err}")
    # Only reached on success: ``_skip`` never returns.
    return local_dir
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
@ -91,3 +114,121 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
assert output
|
||||
print(f"ModelOpt FP8 output: {output}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_quant_method_supported("modelopt"),
    reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
    """Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup.

    Downloads the checkpoint (or skips), loads it with
    ``quantization="modelopt"``, inspects the four projections of decoder
    layer 0, and finally runs a short greedy generation.
    """
    model_path = _snapshot_download_or_skip(
        "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
    )

    with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            # Checked in this fixed order: qkv, o, gate_up, down.
            projections = (
                first_layer.self_attn.qkv_proj,
                first_layer.self_attn.o_proj,
                first_layer.mlp.gate_up_proj,
                first_layer.mlp.down_proj,
            )

            from vllm.model_executor.layers.quantization.modelopt import (
                ModelOptFp8PcPtLinearMethod,
            )

            # Every projection must use the per-channel/per-token method.
            for proj in projections:
                assert isinstance(proj.quant_method, ModelOptFp8PcPtLinearMethod)

            # Weights are stored as FP8 (e4m3fn).
            for proj in projections:
                assert proj.weight.dtype == torch.float8_e4m3fn

            # Per-channel scales; activations are dynamically scaled per token.
            for proj in projections:
                assert hasattr(proj, "weight_scale")
                assert proj.weight_scale.dtype == torch.float32
                assert proj.weight_scale.dim() == 1
                assert not hasattr(proj, "input_scale")

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output
        print(f"ModelOpt FP8_PER_CHANNEL_PER_TOKEN output: {output}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_quant_method_supported("modelopt"),
    reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
    """Test ModelOpt FP8_PB_WO checkpoint setup.

    Downloads the checkpoint (or skips), loads it with
    ``quantization="modelopt"``, inspects the four projections of decoder
    layer 0, and finally runs a short greedy generation.
    """
    model_path = _snapshot_download_or_skip(
        "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
    )

    with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            # Checked in this fixed order: qkv, o, gate_up, down.
            projections = (
                first_layer.self_attn.qkv_proj,
                first_layer.self_attn.o_proj,
                first_layer.mlp.gate_up_proj,
                first_layer.mlp.down_proj,
            )

            from vllm.model_executor.layers.quantization.modelopt import (
                ModelOptFp8PbWoLinearMethod,
            )

            # Every projection must use the per-block weight-only method.
            for proj in projections:
                assert isinstance(proj.quant_method, ModelOptFp8PbWoLinearMethod)

            # Weights are stored as FP8 (e4m3fn).
            for proj in projections:
                assert proj.weight.dtype == torch.float8_e4m3fn

            # Block scales; should be materialized as a 2D [out_blk, in_blk] tensor.
            for proj in projections:
                assert hasattr(proj, "weight_scale")
                assert proj.weight_scale.dtype == torch.float32
                assert proj.weight_scale.dim() == 2

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output
        print(f"ModelOpt FP8_PB_WO output: {output}")
|
||||
|
||||
@ -18,25 +18,37 @@ for i in {1..5}; do
|
||||
echo "Checking metadata.json URL (attempt $i)..."
|
||||
if curl --fail "$meta_json_url" > metadata.json; then
|
||||
echo "INFO: metadata.json URL is valid."
|
||||
# check whether it is valid json by python
|
||||
# check whether it is valid json by python (printed to stdout)
|
||||
if python3 -m json.tool metadata.json; then
|
||||
echo "INFO: metadata.json is valid JSON. Proceeding with the test."
|
||||
echo "INFO: metadata.json is valid JSON. Proceeding with the check."
|
||||
# check whether there is an object in the json matching:
|
||||
# "package_name": "vllm", and "platform_tag" matches the current architecture
|
||||
# see `determine_wheel_url` in setup.py for more details
|
||||
if python3 -c "import platform as p,json as j,sys as s; d = j.load(open('metadata.json')); \
|
||||
s.exit(int(not any(o.get('package_name') == 'vllm' and p.machine() in o.get('platform_tag') \
|
||||
for o in d)))" 2>/dev/null; then
|
||||
echo "INFO: metadata.json contains a pre-compiled wheel for the current architecture."
|
||||
break
|
||||
else
|
||||
echo "WARN: metadata.json does not have a pre-compiled wheel for the current architecture."
|
||||
fi
|
||||
else
|
||||
echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
|
||||
echo "INFO: metadata.json content:"
|
||||
cat metadata.json
|
||||
exit 1
|
||||
fi
|
||||
break
|
||||
fi
|
||||
# failure handling
|
||||
# failure handling & retry logic
|
||||
if [ $i -eq 5 ]; then
|
||||
echo "ERROR: metadata.json URL is still not valid after 5 attempts."
|
||||
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists."
|
||||
echo "ERROR: metadata is still not available after 5 attempts."
|
||||
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit is available."
|
||||
echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes."
|
||||
echo " NOTE: If it fails, please report in #sig-ci channel."
|
||||
exit 1
|
||||
else
|
||||
echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..."
|
||||
sleep 180
|
||||
echo "WARNING: metadata is not available. Retrying after 5 minutes..."
|
||||
sleep 300
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
@ -38,7 +38,8 @@ TOKENIZERS = [
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
# FIXME: mosaicml/mpt-7b has been deleted
|
||||
# "mosaicml/mpt-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"codellama/CodeLlama-7b-hf",
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
import torch.cuda
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
|
||||
@ -14,6 +15,11 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that preprocessing errors are handled gracefully."""
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"Skipped on ROCm: this test only works with 'fork', but ROCm uses 'spawn'."
|
||||
)
|
||||
|
||||
assert not torch.cuda.is_initialized(), (
|
||||
"fork needs to be used for the engine "
|
||||
"core process and this isn't possible if cuda is already initialized"
|
||||
|
||||
@ -306,10 +306,16 @@ def test_prepare_inputs_padded():
|
||||
|
||||
proposer = _create_proposer("eagle", num_speculative_tokens)
|
||||
|
||||
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
|
||||
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
|
||||
output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
|
||||
proposer.prepare_inputs_padded(
|
||||
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
|
||||
)
|
||||
)
|
||||
|
||||
# Verify num_rejected_tokens_gpu is calculated correctly
|
||||
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
|
||||
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
|
||||
|
||||
assert output_metadata.max_query_len == 3
|
||||
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
|
||||
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
|
||||
|
||||
@ -761,7 +761,7 @@ class rocm_aiter_ops:
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_linear_fp8_enaled(cls) -> bool:
|
||||
def is_linear_fp8_enabled(cls) -> bool:
|
||||
return cls.is_linear_enabled()
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -15,7 +15,7 @@ def merge_attn_states(
|
||||
output_lse: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
|
||||
# is not support for FP8 dtype, fallback to use Triton kernel.
|
||||
# does not support FP8 dtype, fallback to use Triton kernel.
|
||||
def supported_dtypes(o: torch.Tensor) -> bool:
|
||||
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
RunnerOption = Literal["auto", RunnerType]
|
||||
ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||
ConvertType = Literal["none", "embed", "classify", "reward", "mm_encoder_only"]
|
||||
ConvertOption = Literal["auto", ConvertType]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
|
||||
@ -811,6 +811,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_text += harmony_parser.last_content_delta or ""
|
||||
cur_channel = harmony_parser.current_channel
|
||||
cur_recipient = harmony_parser.current_recipient
|
||||
# handle the case where several tokens were generated at once
|
||||
# including the final token, leading to a delta in the text
|
||||
# but the current channel to be empty (start state)
|
||||
if not cur_channel and delta_text:
|
||||
cur_channel = "final"
|
||||
else:
|
||||
delta_text = output.text
|
||||
|
||||
|
||||
@ -11,9 +11,11 @@ import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
is_batch_invariant = vllm_is_batch_invariant()
|
||||
|
||||
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
@ -150,7 +152,8 @@ def _get_lora_b_ptr(
|
||||
@functools.lru_cache
|
||||
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
|
||||
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
|
||||
if user_defined_config_folder is not None:
|
||||
# Avoid optimizing for the batch invariant case. Use default config
|
||||
if user_defined_config_folder is not None and not is_batch_invariant:
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
gpu_name = gpu_name.replace(" ", "_")
|
||||
gpu_name = gpu_name.replace("-", "_")
|
||||
@ -203,11 +206,14 @@ def get_lora_op_configs(
|
||||
# default config
|
||||
default = {}
|
||||
if op_type == "shrink":
|
||||
split_k = 64 if batch < 128 else 8
|
||||
if is_batch_invariant:
|
||||
split_k = 1
|
||||
default = {
|
||||
"block_m": 32,
|
||||
"block_n": 16,
|
||||
"block_k": 256 if batch < 128 else 32,
|
||||
"split_k": 64 if batch < 128 else 8,
|
||||
"split_k": split_k,
|
||||
"num_warps": 4,
|
||||
"num_ctas": 1,
|
||||
"group_size_m": 8,
|
||||
|
||||
@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
]
|
||||
|
||||
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||
@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif hidden_states.dtype == torch.float8_e4m3fn:
|
||||
elif (
|
||||
hidden_states.dtype == torch.float8_e4m3fn
|
||||
or hidden_states.dtype == torch.float8_e4m3fnuz
|
||||
):
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def __init__(self, defer_input_quant: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.defer_input_quant = defer_input_quant
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if self.defer_input_quant:
|
||||
return a1, None, None, None, None
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_scale,
|
||||
|
||||
@ -5,11 +5,15 @@ from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
|
||||
|
||||
class QuantMethod(IntEnum):
|
||||
@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
|
||||
a2_scale=quant_config.a2_scale,
|
||||
doweight_stage1=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """Experts implementation that delegates to ``rocm_aiter_fused_experts``.

    Input and output activations both use the ``Standard`` format. Expert
    maps are reported as supported, chunking is not, and the finalize step
    uses ``TopKWeightAndReduceNoOP``.
    """

    def __init__(self, quant_config):
        # Nothing extra to set up; the base class receives quant_config.
        super().__init__(quant_config)

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Return the (input, output) activation formats, both ``Standard``."""
        standard = mk.FusedMoEActivationFormat.Standard
        return (standard, standard)

    def supports_expert_map(self):
        """Report that an expert map may be passed to ``apply``."""
        return True

    def supports_chunking(self):
        """Report that chunked execution is not supported."""
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op top-k weight-and-reduce implementation."""
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return ``(workspace1, workspace2, output)`` shapes.

        Both workspace shapes are ``(0,)`` and the output shape is ``(M, K)``;
        the remaining arguments are not consulted.
        """
        # Workspaces are managed internally by AITER.
        no_workspace = (0,)
        return (no_workspace, no_workspace, (M, K))

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the AITER fused experts and copy the result into ``output``.

        ``a1q_scale``, ``a2_scale`` and ``expert_tokens_meta`` must all be
        ``None``. ``global_num_experts``, ``workspace13`` and ``workspace2``
        are accepted but not used here.
        """
        assert a1q_scale is None
        assert a2_scale is None
        assert expert_tokens_meta is None

        fused_out = rocm_aiter_fused_experts(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            quant_config=self.quant_config,
        )
        # The caller-provided buffer must already have the result's shape.
        assert fused_out.shape == output.shape
        output.copy_(fused_out)
|
||||
|
||||
@ -6,6 +6,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fused_batched_moe import BatchedTritonExperts
|
||||
from .fused_moe import TritonExperts, fused_experts
|
||||
from .fused_moe import TritonExperts
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
TritonExperts = None # type: ignore
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
else:
|
||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
elif current_platform.is_cuda_alike():
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(self.moe_quant_config),
|
||||
shared_experts=None,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
if self.moe.has_bias:
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
result = fused_experts(
|
||||
result = self.kernel(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
|
||||
@ -53,6 +53,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"GPTQLinearMethod",
|
||||
"FBGEMMFp8LinearMethod",
|
||||
"ModelOptFp8LinearMethod",
|
||||
"ModelOptFp8PcPtLinearMethod",
|
||||
"ModelOptFp8PbWoLinearMethod",
|
||||
"IPEXAWQLinearMethod",
|
||||
"IPEXGPTQLinearMethod",
|
||||
"HQQMarlinMethod",
|
||||
|
||||
@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
)
|
||||
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
assert not self.is_static_input_scheme
|
||||
|
||||
@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
|
||||
DEEPGEMM = 3
|
||||
MARLIN = 4
|
||||
TRITON = 5
|
||||
AITER = 6
|
||||
|
||||
|
||||
def get_fp8_moe_backend(
|
||||
@ -189,6 +190,10 @@ def get_fp8_moe_backend(
|
||||
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
|
||||
return Fp8MoeBackend.DEEPGEMM
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
|
||||
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
|
||||
return Fp8MoeBackend.AITER
|
||||
|
||||
# default to Triton
|
||||
logger.info_once("Using Triton backend for FP8 MoE")
|
||||
return Fp8MoeBackend.TRITON
|
||||
@ -414,7 +419,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
self.rocm_aiter_moe_enabled = False
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# Lazy import to avoid importing triton too early.
|
||||
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
start += shard_size
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.moe_quant_config = config
|
||||
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
|
||||
# with the changes to defer input quantization
|
||||
FlashInferAllGatherMoEPrepareAndFinalize(
|
||||
use_dp=(self.moe.dp_size > 1),
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
Fp8MoeBackend.DEEPGEMM,
|
||||
Fp8MoeBackend.TRITON,
|
||||
Fp8MoeBackend.MARLIN,
|
||||
Fp8MoeBackend.AITER,
|
||||
]:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
config = self.get_fused_moe_quant_config(layer)
|
||||
assert config is not None
|
||||
self.moe_quant_config = config
|
||||
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||
moe_kernel = (
|
||||
MarlinExperts(quant_config=self.moe_quant_config)
|
||||
if use_marlin
|
||||
else TritonOrDeepGemmExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=allow_deep_gemm,
|
||||
)
|
||||
)
|
||||
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(), moe_kernel
|
||||
)
|
||||
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
# TODO: make defer_input_quant an attr of the AiterExperts
|
||||
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
|
||||
AiterExperts(quant_config=self.moe_quant_config),
|
||||
)
|
||||
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
MarlinExperts(quant_config=self.moe_quant_config),
|
||||
)
|
||||
else:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonOrDeepGemmExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||
),
|
||||
)
|
||||
self.use_inplace = True
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if (
|
||||
self.rocm_aiter_moe_enabled
|
||||
self.fp8_backend == Fp8MoeBackend.AITER
|
||||
or self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
):
|
||||
@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
|
||||
assert (
|
||||
self.fp8_backend != Fp8MoeBackend.MARLIN
|
||||
) and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet."
|
||||
)
|
||||
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
|
||||
raise NotImplementedError(
|
||||
"Marlin and ROCm AITER are not supported with all2all yet."
|
||||
)
|
||||
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
# TODO(rob): convert this to MK.
|
||||
result = rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
result = self.kernel(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
result = self.kernel(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
self.rocm_aiter_moe_enabled = False
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# Lazy import to avoid importing triton too early.
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
# Reshuffle weights for AITER if needed.
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
# Reshuffle weights for MARLIN if needed.
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
|
||||
@ -51,7 +51,7 @@ class QuantFP8(CustomOp):
|
||||
self.column_major_scales = column_major_scales
|
||||
self.use_ue8m0 = use_ue8m0
|
||||
|
||||
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
|
||||
self.is_group_quant = group_shape.is_per_group()
|
||||
if self.is_group_quant:
|
||||
|
||||
@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
select_cutlass_fp8_gemm_impl,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
)
|
||||
@ -72,9 +75,15 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_scaled_fp4_mm,
|
||||
@ -88,7 +97,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
QUANT_ALGOS = [
|
||||
# FP8 (per-tensor weight + optional static activation scale).
|
||||
"FP8",
|
||||
# FP8 per-channel weight scale + per-token activation scale.
|
||||
"FP8_PER_CHANNEL_PER_TOKEN",
|
||||
# FP8 per-block weight-only (ModelOpt may emit this as lowercase).
|
||||
"FP8_PB_WO",
|
||||
# FP4
|
||||
"NVFP4",
|
||||
]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
|
||||
@ -255,6 +273,9 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
if not quant_method:
|
||||
raise ValueError("Missing 'quant_algo' in quantization config")
|
||||
|
||||
# Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
|
||||
quant_method = str(quant_method).upper()
|
||||
|
||||
if kv_cache_quant_method is None:
|
||||
# No KV cache quantization, keep this branch just to have this comment
|
||||
pass
|
||||
@ -263,6 +284,8 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
f"kv_cache_quant_algo must be a string, got "
|
||||
f"{type(kv_cache_quant_method)}"
|
||||
)
|
||||
else:
|
||||
kv_cache_quant_method = kv_cache_quant_method.upper()
|
||||
|
||||
if not isinstance(exclude_modules, list):
|
||||
raise ValueError(
|
||||
@ -302,17 +325,34 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_method: str,
|
||||
is_checkpoint_fp8_serialized: bool,
|
||||
kv_cache_quant_method: str | None,
|
||||
exclude_modules: list[str],
|
||||
) -> None:
|
||||
super().__init__(exclude_modules)
|
||||
self.quant_method = quant_method
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning(
|
||||
"Detected ModelOpt fp8 checkpoint. Please note that"
|
||||
" the format is experimental and could change."
|
||||
"Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
|
||||
"that the format is experimental and could change.",
|
||||
quant_method,
|
||||
)
|
||||
|
||||
# Select LinearMethod implementation based on quant_algo.
|
||||
if self.quant_method == "FP8":
|
||||
self.LinearMethodCls = ModelOptFp8LinearMethod
|
||||
elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
|
||||
self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
|
||||
elif self.quant_method == "FP8_PB_WO":
|
||||
self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported ModelOpt FP8 quant_algo for vLLM: "
|
||||
f"{self.quant_method}. Supported: FP8 / "
|
||||
"FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
|
||||
)
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
@ -346,13 +386,13 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = quant_config.get("quant_algo", "")
|
||||
if "FP8" in quant_algo:
|
||||
quant_algo = str(quant_config.get("quant_algo", ""))
|
||||
if "FP8" in quant_algo.upper():
|
||||
return "modelopt"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = hf_quant_cfg.get("quant_algo", "")
|
||||
if isinstance(quant_algo, str) and "FP8" in quant_algo:
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
|
||||
if "FP8" in quant_algo.upper():
|
||||
return "modelopt"
|
||||
|
||||
return None
|
||||
@ -369,7 +409,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
) -> "ModelOptFp8Config":
|
||||
is_checkpoint_fp8_serialized = "FP8" in quant_method
|
||||
|
||||
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
|
||||
return cls(
|
||||
quant_method,
|
||||
is_checkpoint_fp8_serialized,
|
||||
kv_cache_quant_method,
|
||||
exclude_modules,
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
@ -464,6 +509,203 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.

    Checkpoint layout expected for every Linear layer:
        - weight: fp8-e4m3fn, shape [out, in]
        - weight_scale: fp32, shape [out] (one scale per output channel)
        - no input_scale (activations are quantized dynamically, per token)
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Non-static activation quantization using the per-token group shape.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and fp32 ``weight_scale`` on ``layer``.

        Also records ``logical_widths``, ``input_size_per_partition`` and
        ``output_size_per_partition`` on the layer.

        Raises:
            ValueError: if the config does not describe an FP8-serialized
                checkpoint.
        """
        # The unpartitioned sizes are not used by this method.
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        fp8_storage = torch.empty(
            out_features,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn,
        )
        fp8_weight = ModelWeightParameter(
            data=fp8_storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", fp8_weight)

        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(out_features, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-fill with the lowest finite fp32 value (presumably a sentinel
        # for scales that have not been loaded yet -- confirm).
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Re-wrap the loaded tensors as plain, non-trainable Parameters.

        The weight is stored transposed relative to how it was loaded.
        """
        transposed_weight = layer.weight.t()
        layer.weight = Parameter(transposed_weight, requires_grad=False)

        scale_data = layer.weight_scale.data
        layer.weight_scale = Parameter(scale_data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x``; no static input scale is passed."""
        linear_op = self.fp8_linear
        return linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )
|
||||
|
||||
|
||||
class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
        [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    # (block_n, block_k): block extent along the output and input dims.
    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)

        n_block, k_block = self._WEIGHT_BLOCK_SIZE
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(n_block, k_block),
            act_quant_group_shape=GroupShape(1, k_block),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and block-wise ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized, or if either
                partitioned dimension is not a multiple of its block size.
        """
        # The unpartitioned sizes are not used by this method.
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        out_features = sum(output_partition_sizes)
        in_features = input_size_per_partition
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = in_features
        layer.output_size_per_partition = out_features

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        fp8_storage = torch.empty(
            out_features,
            in_features,
            dtype=torch.float8_e4m3fn,
        )
        fp8_weight = ModelWeightParameter(
            data=fp8_storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", fp8_weight)

        n_block, k_block = self._WEIGHT_BLOCK_SIZE
        if out_features % n_block != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{n_block}, got {out_features}."
            )
        if in_features % k_block != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{k_block}, got {in_features}."
            )

        n_out_blocks = out_features // n_block
        n_in_blocks = in_features // k_block

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        scale_shape = (n_out_blocks, 1, n_in_blocks, 1)
        block_scale = BlockQuantScaleParameter(
            data=torch.empty(scale_shape, dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-fill with the lowest finite fp32 value (presumably a sentinel
        # for scales that have not been loaded yet -- confirm).
        block_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", block_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Re-wrap weight and scale as plain Parameters; flatten 4D scales.

        Raises:
            ValueError: if ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        rank = scale.dim()
        if rank == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif rank != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-FP8 linear op on ``x``; no static input scale."""
        linear_op = self.w8a8_block_fp8_linear
        return linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )
|
||||
|
||||
|
||||
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for ModelOpt FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
|
||||
@ -189,7 +189,9 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
|
||||
)
|
||||
|
||||
convert_type = model_config.convert_type
|
||||
if convert_type != "none" and supports_multimodal(model_cls):
|
||||
if convert_type not in ["none", "mm_encoder_only"] and supports_multimodal(
|
||||
model_cls
|
||||
):
|
||||
logger.debug_once("Detected conversion of Multi Modal model.")
|
||||
converted = try_create_mm_pooling_model_cls(model_cls)
|
||||
if converted is not None:
|
||||
@ -200,6 +202,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
|
||||
|
||||
if convert_type == "none":
|
||||
pass
|
||||
elif convert_type == "mm_encoder_only":
|
||||
logger.debug_once("Converting to mm encoder only model.")
|
||||
from vllm.model_executor.models.adapters import as_mm_encoder_only_model
|
||||
|
||||
model_cls = as_mm_encoder_only_model(model_cls)
|
||||
elif convert_type == "embed":
|
||||
logger.debug_once("Converting to embedding model.")
|
||||
model_cls = as_embedding_model(model_cls)
|
||||
|
||||
@ -520,3 +520,64 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
method = getattr(text_config, "method", None)
|
||||
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
|
||||
return SEQ_CLS_LOAD_METHODS[method](model, weights)
|
||||
|
||||
|
||||
def as_mm_encoder_only_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM vl model to support mm encoder only for
    EPD encoder instances.

    Classes without an ``embed_multimodal`` attribute are returned unchanged.
    Otherwise the class must expose ``get_language_model_spec()`` returning
    ``(lm_model_cls, lm_attr)``; the returned subclass builds the parent model
    while the language model's constructor is stubbed out, removes the
    ``lm_attr`` attribute afterwards, and skips ``"<lm_attr>."``-prefixed
    weights when loading.

    Raises:
        TypeError: if ``get_language_model_spec`` is missing or returns
            ``None`` for either element.
    """
    if not hasattr(cls, "embed_multimodal"):
        # Submodel case: return the original class.
        return cls

    if not hasattr(cls, "get_language_model_spec"):
        raise TypeError(f"{cls} need to implement `get_language_model_spec` method.")

    lm_model_cls, lm_attr = cls.get_language_model_spec()

    if lm_model_cls is None or lm_attr is None:
        raise TypeError(
            f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
        )

    class DummyLM(nn.Module):
        # Only its __init__ is used, as a stand-in constructor for
        # ``lm_model_cls`` while the parent model is being built.
        def __init__(self, *args, **kwargs):
            self.make_empty_intermediate_tensors = None

    class ModelForMMEncoderOnly(cls):
        # Declared on the class (rather than assigned per instance in
        # __init__) so that attribute checks see the marker on the converted
        # class itself as well as on its instances.
        is_mm_encoder_only_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            # NOTE: this swaps a class-level attribute, so the stub is visible
            # to every user of ``lm_model_cls`` until ``finally`` restores it.
            origin_init = lm_model_cls.__init__
            try:
                lm_model_cls.__init__ = DummyLM.__init__
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

                # Drop the placeholder language model if the parent created it.
                if hasattr(self, lm_attr):
                    delattr(self, lm_attr)
            finally:
                lm_model_cls.__init__ = origin_init

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            from .utils import AutoWeightsLoader

            origin_init_ = AutoWeightsLoader.__init__

            def _new_init_(self, *args, **kwargs):
                origin_init_(self, *args, **kwargs)
                # Make any loader built by the parent skip LM weights.
                self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."]

            try:
                # Temporarily patched class-wide; restored in ``finally``.
                AutoWeightsLoader.__init__ = _new_init_
                result = super().load_weights(weights)
            finally:
                AutoWeightsLoader.__init__ = origin_init_
            return result

    return ModelForMMEncoderOnly  # type: ignore
|
||||
|
||||
@ -487,7 +487,7 @@ class BagelForConditionalGeneration(
|
||||
# Split by image
|
||||
return tuple(vision_embeds)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
"""Get multimodal embeddings from input."""
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@ -401,7 +401,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
# of attention tokens that would fit mamba_page_size:
|
||||
# e.g. for mamba page size = 788kB
|
||||
# attn_1_token = 2kB -> fits ~394 tokens
|
||||
# then round up to a mulitple of 256 -> 512 tokens
|
||||
# then round up to a multiple of 256 -> 512 tokens
|
||||
# End result:
|
||||
# attn_block_size = 512
|
||||
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
|
||||
|
||||
@ -141,6 +141,14 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
    """
    Return the language model spec:
    (language model class, language model attr)

    This default implementation returns ``(None, None)``.

    NOTE(review): the first element is described as a class while the
    annotation names ``nn.Module`` -- confirm whether ``type[nn.Module]``
    is the intended type.
    """
    return None, None
|
||||
|
||||
@overload
|
||||
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
|
||||
|
||||
@ -302,6 +310,10 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
|
||||
return getattr(model, "supports_encoder_tp_data", False)
|
||||
|
||||
|
||||
def supports_mm_encoder_only(model: type[object] | object) -> bool:
|
||||
return getattr(model, "is_mm_encoder_only_model", False)
|
||||
|
||||
|
||||
@overload
|
||||
def supports_multimodal_pruning(
|
||||
model: type[object],
|
||||
|
||||
@ -34,7 +34,7 @@ import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, Qwen2ForCausalLM
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
@ -1567,3 +1567,11 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
connector="visual.merger.",
|
||||
tower_model="visual.",
|
||||
)
|
||||
|
||||
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
    """
    Return the language model spec:
    (language model class, language model attr)

    Here that is ``Qwen2ForCausalLM`` and the attribute name
    ``"language_model"``.

    NOTE(review): ``Qwen2ForCausalLM`` is the name imported from
    ``transformers`` in this module -- confirm that is the class whose
    construction should be targeted.
    """
    return Qwen2ForCausalLM, "language_model"
|
||||
|
||||
@ -323,7 +323,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
# vit pos embeding, TODO: spatial_patch_size vs patch_size
|
||||
# vit pos embedding, TODO: spatial_patch_size vs patch_size
|
||||
if self.apply_vit_abs_pos_embed:
|
||||
self.pos_embed = nn.Embedding(self.num_grid_per_side**2, self.hidden_size)
|
||||
else:
|
||||
|
||||
@ -2090,3 +2090,11 @@ class Qwen3VLForConditionalGeneration(
|
||||
connector="visual.merger",
|
||||
tower_model="visual.",
|
||||
)
|
||||
|
||||
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
    """
    Return the language model spec:
    (language model class, language model attr)

    Here that is ``Qwen3LLMForCausalLM`` and the attribute name
    ``"language_model"``.
    """
    return Qwen3LLMForCausalLM, "language_model"
|
||||
|
||||
@ -408,7 +408,7 @@ class RocmPlatform(Platform):
|
||||
parallel_config = vllm_config.parallel_config
|
||||
is_eager_execution = compilation_config == CUDAGraphMode.NONE
|
||||
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
|
||||
if compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# decode context parallel does not support full cudagraphs
|
||||
|
||||
@ -156,7 +156,9 @@ class XPUPlatform(Platform):
|
||||
|
||||
if vllm_config.lora_config is not None:
|
||||
compilation_config.mode = CompilationMode.NONE
|
||||
|
||||
# decrease triton kernel compilation scratch space for speculative decoding
|
||||
if vllm_config.speculative_config is not None:
|
||||
os.environ["IGC_ForceOCLSIMDWidth"] = "16" # noqa: SIM112
|
||||
# check and update parallel config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# Only override worker_cls if it's still the default "auto"
|
||||
|
||||
@ -104,7 +104,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
|
||||
# 3. Both BOT and EOT have been outputted.
|
||||
elif has_bot_token and has_eot_token:
|
||||
return input_ids[:bot_token_index] + input_ids[eot_token_index + 1 :]
|
||||
# 4. Only EOT has been outputted => this should not have occured for a model
|
||||
# 4. Only EOT has been outputted => this should not have occurred for a model
|
||||
# well prompted and trained.
|
||||
else:
|
||||
return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]
|
||||
|
||||
@ -138,37 +138,167 @@ class MinimaxM2ToolParser(ToolParser):
|
||||
return name_str
|
||||
|
||||
def _convert_param_value(self, value: str, param_type: str) -> Any:
|
||||
"""Convert parameter value to the correct type."""
|
||||
"""Convert parameter value to the correct type (legacy single-type version)."""
|
||||
return self._convert_param_value_with_types(value, [param_type])
|
||||
|
||||
def _extract_types_from_schema(self, schema: Any) -> list[str]:
|
||||
"""
|
||||
Extract all possible types from a JSON schema definition.
|
||||
Handles anyOf, oneOf, allOf, type arrays, and enum fields.
|
||||
|
||||
Args:
|
||||
schema: The JSON schema definition for a parameter
|
||||
|
||||
Returns:
|
||||
List of type strings (e.g., ["string", "integer", "null"])
|
||||
"""
|
||||
if schema is None:
|
||||
return ["string"]
|
||||
|
||||
if not isinstance(schema, dict):
|
||||
return ["string"]
|
||||
|
||||
types: set[str] = set()
|
||||
|
||||
# Handle direct "type" field
|
||||
if "type" in schema:
|
||||
type_value = schema["type"]
|
||||
if isinstance(type_value, str):
|
||||
types.add(type_value)
|
||||
elif isinstance(type_value, list):
|
||||
for t in type_value:
|
||||
if isinstance(t, str):
|
||||
types.add(t)
|
||||
|
||||
# Handle enum - infer types from enum values
|
||||
if "enum" in schema and isinstance(schema["enum"], list) and schema["enum"]:
|
||||
for value in schema["enum"]:
|
||||
if value is None:
|
||||
types.add("null")
|
||||
elif isinstance(value, bool):
|
||||
types.add("boolean")
|
||||
elif isinstance(value, int):
|
||||
types.add("integer")
|
||||
elif isinstance(value, float):
|
||||
types.add("number")
|
||||
elif isinstance(value, str):
|
||||
types.add("string")
|
||||
elif isinstance(value, list):
|
||||
types.add("array")
|
||||
elif isinstance(value, dict):
|
||||
types.add("object")
|
||||
|
||||
# Handle anyOf, oneOf, allOf - recursively extract types
|
||||
for choice_field in ("anyOf", "oneOf", "allOf"):
|
||||
if choice_field in schema and isinstance(schema[choice_field], list):
|
||||
for choice in schema[choice_field]:
|
||||
extracted = self._extract_types_from_schema(choice)
|
||||
types.update(extracted)
|
||||
|
||||
# If no types found, default to string
|
||||
if not types:
|
||||
return ["string"]
|
||||
|
||||
return list(types)
|
||||
|
||||
def _convert_param_value_with_types(
|
||||
self, value: str, param_types: list[str]
|
||||
) -> Any:
|
||||
"""
|
||||
Convert parameter value to the correct type based on a list of possible types.
|
||||
Tries each type in order until one succeeds.
|
||||
|
||||
Args:
|
||||
value: The string value to convert
|
||||
param_types: List of possible type strings
|
||||
|
||||
Returns:
|
||||
The converted value
|
||||
"""
|
||||
if value.lower() == "null":
|
||||
return None
|
||||
|
||||
param_type = param_type.lower()
|
||||
if param_type in ["string", "str", "text"]:
|
||||
# Normalize types
|
||||
normalized_types = [t.lower() for t in param_types]
|
||||
|
||||
# Try null first if it's in the list
|
||||
if "null" in normalized_types or value.lower() in ("null", "none", "nil"):
|
||||
return None
|
||||
|
||||
# Try each type in order of preference (most specific first, string as fallback)
|
||||
# Priority: integer > number > boolean > object > array > string
|
||||
type_priority = [
|
||||
"integer",
|
||||
"int",
|
||||
"number",
|
||||
"float",
|
||||
"boolean",
|
||||
"bool",
|
||||
"object",
|
||||
"array",
|
||||
"string",
|
||||
"str",
|
||||
"text",
|
||||
]
|
||||
|
||||
for param_type in type_priority:
|
||||
if param_type not in normalized_types:
|
||||
continue
|
||||
|
||||
if param_type in ["string", "str", "text"]:
|
||||
return value
|
||||
elif param_type in ["integer", "int"]:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif param_type in ["number", "float"]:
|
||||
try:
|
||||
val = float(value)
|
||||
return val if val != int(val) else int(val)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif param_type in ["boolean", "bool"]:
|
||||
lower_val = value.lower().strip()
|
||||
if lower_val in ["true", "1", "yes", "on"]:
|
||||
return True
|
||||
elif lower_val in ["false", "0", "no", "off"]:
|
||||
return False
|
||||
continue
|
||||
elif param_type in ["object", "array"]:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Fallback: try JSON parse, then return as string
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
elif param_type in ["integer", "int"]:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["number", "float"]:
|
||||
try:
|
||||
val = float(value)
|
||||
return val if val != int(val) else int(val)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["boolean", "bool"]:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif param_type in ["object", "array"]:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
else:
|
||||
# Try JSON parse first, fallback to string
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def _get_param_types_from_config(
|
||||
self, param_name: str, param_config: dict
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get parameter types from parameter configuration.
|
||||
Handles anyOf, oneOf, allOf, and direct type definitions.
|
||||
|
||||
Args:
|
||||
param_name: The name of the parameter
|
||||
param_config: The properties dict from the tool schema
|
||||
|
||||
Returns:
|
||||
List of type strings
|
||||
"""
|
||||
if param_name not in param_config:
|
||||
return ["string"]
|
||||
|
||||
param_schema = param_config[param_name]
|
||||
if not isinstance(param_schema, dict):
|
||||
return ["string"]
|
||||
|
||||
return self._extract_types_from_schema(param_schema)
|
||||
|
||||
def _parse_single_invoke(
|
||||
self, invoke_str: str, tools: list | None
|
||||
@ -207,17 +337,11 @@ class MinimaxM2ToolParser(ToolParser):
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
# Get parameter type
|
||||
param_type = "string"
|
||||
if (
|
||||
param_name in param_config
|
||||
and isinstance(param_config[param_name], dict)
|
||||
and "type" in param_config[param_name]
|
||||
):
|
||||
param_type = param_config[param_name]["type"]
|
||||
# Get parameter types (supports anyOf/oneOf/allOf)
|
||||
param_type = self._get_param_types_from_config(param_name, param_config)
|
||||
|
||||
# Convert value
|
||||
param_dict[param_name] = self._convert_param_value(
|
||||
param_dict[param_name] = self._convert_param_value_with_types(
|
||||
param_value, param_type
|
||||
)
|
||||
|
||||
@ -593,7 +717,7 @@ class MinimaxM2ToolParser(ToolParser):
|
||||
# Store raw value for later processing
|
||||
self.accumulated_params[self.current_param_name] = param_value
|
||||
|
||||
# Get parameter configuration for type conversion
|
||||
# Get parameter configuration with anyOf support
|
||||
param_config = {}
|
||||
if self.streaming_request and self.streaming_request.tools:
|
||||
for tool in self.streaming_request.tools:
|
||||
@ -610,17 +734,12 @@ class MinimaxM2ToolParser(ToolParser):
|
||||
param_config = params["properties"]
|
||||
break
|
||||
|
||||
# Get parameter type
|
||||
param_type = "string"
|
||||
if (
|
||||
self.current_param_name in param_config
|
||||
and isinstance(param_config[self.current_param_name], dict)
|
||||
and "type" in param_config[self.current_param_name]
|
||||
):
|
||||
param_type = param_config[self.current_param_name]["type"]
|
||||
# Get parameter types (supports anyOf/oneOf/allOf)
|
||||
param_type = self._get_param_types_from_config(
|
||||
self.current_param_name, param_config
|
||||
)
|
||||
|
||||
# Convert param value to appropriate type
|
||||
converted_value = self._convert_param_value(
|
||||
converted_value = self._convert_param_value_with_types(
|
||||
param_value, param_type
|
||||
)
|
||||
|
||||
|
||||
@ -389,7 +389,7 @@ def should_use_deepgemm_for_fp8_linear(
|
||||
|
||||
# Verify DeepGEMM N/K dims requirements
|
||||
# NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
|
||||
# test inside kernels/quatization/test_block_fp8.py
|
||||
# test inside kernels/quantization/test_block_fp8.py
|
||||
N_MULTIPLE = 64
|
||||
K_MULTIPLE = 128
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (
|
||||
|
||||
@ -355,6 +355,8 @@ class MLACommonPrefillMetadata:
|
||||
max_query_len: int
|
||||
chunked_context: ChunkedContextMetadata | None = None
|
||||
query_seq_lens: torch.Tensor | None = None
|
||||
workspace_buffer: torch.Tensor | None = None
|
||||
q_data_type: torch.dtype | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -558,6 +560,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
self.dcp_rank = 0
|
||||
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
|
||||
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
||||
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
|
||||
|
||||
# Don't try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
@ -722,8 +725,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
@ -773,13 +776,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
|
||||
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
@ -794,6 +791,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
|
||||
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
@ -983,19 +982,29 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
prefill_metadata.query_seq_lens = (
|
||||
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
|
||||
)
|
||||
prefill_metadata.workspace_buffer = self._workspace_buffer
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
dcp_tot_seq_lens_device = None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
|
||||
seq_lens_cpu = dcp_local_seq_lens_cpu
|
||||
seq_lens = dcp_local_seq_lens
|
||||
|
||||
# After DCP distribution, the maximum number of tokens for any rank is
|
||||
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
|
||||
# and I is cp_kv_cache_interleave_size.
|
||||
# This eliminates GPU->CPU sync while minimizing workspace
|
||||
# over-allocation.
|
||||
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
|
||||
max_seq_len = (
|
||||
(max_seq_len + num_partitions - 1) // num_partitions
|
||||
) * self.cp_kv_cache_interleave_size
|
||||
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||
seq_lens_cpu=seq_lens_cpu[:num_decodes],
|
||||
seq_lens_device=seq_lens[:num_decodes],
|
||||
max_seq_len=max_seq_len,
|
||||
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
|
||||
query_start_loc_device=query_start_loc[: num_decodes + 1],
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
@ -1491,12 +1500,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||
|
||||
assert prefill.query_seq_lens is not None
|
||||
assert prefill.workspace_buffer is not None
|
||||
|
||||
ret = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
workspace_buffer=prefill.workspace_buffer,
|
||||
seq_lens=prefill.query_seq_lens,
|
||||
max_q_len=prefill.max_query_len,
|
||||
max_kv_len=prefill.max_query_len,
|
||||
@ -1525,6 +1535,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
assert prefill.chunked_context is not None
|
||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||
assert prefill.workspace_buffer is not None
|
||||
|
||||
out = torch.zeros(
|
||||
q.shape[0],
|
||||
@ -1533,13 +1544,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
device=q.device,
|
||||
dtype=q.dtype,
|
||||
)
|
||||
self._workspace_buffer.fill_(0)
|
||||
prefill.workspace_buffer.fill_(0)
|
||||
|
||||
attn_out, lse = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
workspace_buffer=prefill.workspace_buffer,
|
||||
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
|
||||
max_q_len=prefill.max_query_len,
|
||||
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
|
||||
|
||||
@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
max_seq_len = seq_lens_cpu.max().item()
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
max_num_splits = 1
|
||||
|
||||
scheduler_metadata = self._schedule_decode(
|
||||
num_reqs=seq_lens_cpu.numel(),
|
||||
num_reqs=seq_lens_device.shape[0],
|
||||
cu_query_lens=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=seq_lens_device,
|
||||
|
||||
@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
|
||||
@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
|
||||
@ -236,6 +236,7 @@ class EagleProposer:
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
@ -414,6 +415,17 @@ class EagleProposer:
|
||||
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
||||
self.token_arange_np[: batch_size + 1]
|
||||
).clone()
|
||||
|
||||
# In padded drafter batch, we need to adjust the sequence lengths
|
||||
# to remove the "padding" (i.e. rejected tokens).
|
||||
# Only apply this adjustment when we have rejected tokens
|
||||
# (i.e., not the first proposal).
|
||||
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
|
||||
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
|
||||
# Invalidate the CPU-side shadows to avoid H<>D sync.
|
||||
common_attn_metadata._seq_lens_cpu = None
|
||||
common_attn_metadata._num_computed_tokens_cpu = None
|
||||
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
@ -628,13 +640,14 @@ class EagleProposer:
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
spec_decode_metadata: SpecDecodeMetadata,
|
||||
valid_sampled_tokens_count: torch.Tensor,
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding
|
||||
It updates the common_attn_metadata for speculative decoding,
|
||||
but does not consider the rejected tokens. Instead, all tokens
|
||||
are included as inputs to the speculator, with the rejected tokens
|
||||
used as padding and filtered out later by `token_indices_to_sample`.
|
||||
No blocking CPU operations should be introduced in this function.
|
||||
"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
device = valid_sampled_tokens_count.device
|
||||
@ -642,14 +655,17 @@ class EagleProposer:
|
||||
token_indices_to_sample = torch.empty(
|
||||
(num_reqs,), dtype=torch.int32, device=device
|
||||
)
|
||||
num_rejected_tokens_gpu = torch.empty(
|
||||
(num_reqs,), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Kernel grid: one program per request (row)
|
||||
grid = (num_reqs,)
|
||||
eagle_prepare_inputs_padded_kernel[grid](
|
||||
spec_decode_metadata.cu_num_draft_tokens,
|
||||
valid_sampled_tokens_count,
|
||||
common_attn_metadata.query_start_loc,
|
||||
token_indices_to_sample,
|
||||
num_rejected_tokens_gpu,
|
||||
num_reqs,
|
||||
)
|
||||
|
||||
@ -674,7 +690,11 @@ class EagleProposer:
|
||||
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
return spec_common_attn_metadata, token_indices_to_sample
|
||||
return (
|
||||
spec_common_attn_metadata,
|
||||
token_indices_to_sample,
|
||||
num_rejected_tokens_gpu,
|
||||
)
|
||||
|
||||
def propose_tree(
|
||||
self,
|
||||
|
||||
@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
|
||||
valid_sampled_tokens_count_ptr, # [num_reqs]
|
||||
query_start_loc_gpu_ptr, # [num_reqs + 1]
|
||||
token_indices_to_sample_ptr, # [num_reqs] (output)
|
||||
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
|
||||
num_reqs, # tl.int32
|
||||
):
|
||||
"""
|
||||
@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
|
||||
|
||||
index_to_sample = q_last_tok_idx - num_rejected_tokens
|
||||
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
|
||||
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@ -66,6 +66,7 @@ from vllm.model_executor.models.interfaces import (
|
||||
SupportsXDRoPE,
|
||||
is_mixture_of_experts,
|
||||
supports_eagle3,
|
||||
supports_mm_encoder_only,
|
||||
supports_mrope,
|
||||
supports_multimodal_pruning,
|
||||
supports_transcription,
|
||||
@ -3533,6 +3534,7 @@ class GPUModelRunner(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
|
||||
num_rejected_tokens_gpu = None
|
||||
if spec_decode_metadata is None:
|
||||
token_indices_to_sample = None
|
||||
# input_ids can be None for multimodal models.
|
||||
@ -3563,12 +3565,14 @@ class GPUModelRunner(
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
else:
|
||||
common_attn_metadata, token_indices_to_sample = (
|
||||
self.drafter.prepare_inputs_padded(
|
||||
common_attn_metadata,
|
||||
spec_decode_metadata,
|
||||
valid_sampled_tokens_count,
|
||||
)
|
||||
(
|
||||
common_attn_metadata,
|
||||
token_indices_to_sample,
|
||||
num_rejected_tokens_gpu,
|
||||
) = self.drafter.prepare_inputs_padded(
|
||||
common_attn_metadata,
|
||||
spec_decode_metadata,
|
||||
valid_sampled_tokens_count,
|
||||
)
|
||||
total_num_tokens = common_attn_metadata.num_actual_tokens
|
||||
# When padding the batch, token_indices is just a range
|
||||
@ -3599,6 +3603,7 @@ class GPUModelRunner(
|
||||
sampling_metadata=sampling_metadata,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
mm_embed_inputs=mm_embed_inputs,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
@ -4067,6 +4072,11 @@ class GPUModelRunner(
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
activate_lora: If False, dummy_run is performed without LoRAs.
|
||||
"""
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# The current dummy run only covers LM execution, so we can skip it.
|
||||
# mm encoder dummy run may need to add in the future.
|
||||
return torch.tensor([]), torch.tensor([])
|
||||
|
||||
assert (
|
||||
cudagraph_runtime_mode is None
|
||||
or cudagraph_runtime_mode.valid_runtime_modes()
|
||||
@ -4344,6 +4354,11 @@ class GPUModelRunner(
|
||||
# The dummy hidden states may contain special values,
|
||||
# like `inf` or `nan`.
|
||||
# To avoid breaking the sampler, we use a random tensor here instead.
|
||||
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# MM Encoder only model no need to run sampler.
|
||||
return torch.tensor([])
|
||||
|
||||
hidden_states = torch.rand_like(hidden_states)
|
||||
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
@ -4472,6 +4487,10 @@ class GPUModelRunner(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> PoolerOutput:
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# MM Encoder only model not need to run pooler.
|
||||
return torch.tensor([])
|
||||
|
||||
# Find the task that has the largest output for subsequent steps
|
||||
supported_pooling_tasks = self.get_supported_pooling_tasks()
|
||||
|
||||
|
||||
@ -634,7 +634,12 @@ class Worker(WorkerBase):
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiling is not enabled.")
|
||||
raise RuntimeError(
|
||||
"Profiling is not enabled. Please set --profiler-config to enable "
|
||||
"profiling. Example: "
|
||||
"'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
|
||||
"=YOUR_DIR_PATH_TO_DUMP_TRACE'"
|
||||
)
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user