Merge branch 'main' into mlm-full-lora-support

commit 94dce5c3d9

@@ -1223,6 +1223,8 @@ steps:
   # FIXIT: find out which code initialize cuda before running the test
   # before the fix, we need to use spawn to test it
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  # A lot of these tests are on the edge of OOMing
+  - export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
   # There is some Tensor Parallelism related processing logic in LoRA that
   # requires multi-GPU testing for validation.
   - pytest -v -s -x lora/test_chatglm3_tp.py

@@ -3,7 +3,7 @@

 from collections import OrderedDict
 from typing import NamedTuple
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import pytest
 from huggingface_hub.utils import HfHubHTTPError
@@ -194,5 +194,8 @@ def test_get_adapter_absolute_path_huggingface_error(
     # Hugging Face model identifier with download error
     path = "org/repo"
     mock_exist.return_value = False
-    mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info")
+    mock_snapshot_download.side_effect = HfHubHTTPError(
+        "failed to query model info",
+        response=MagicMock(),
+    )
     assert get_adapter_absolute_path(path) == path

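The updated mock reflects newer huggingface_hub releases, where HfHubHTTPError carries an attached HTTP response object. A minimal standalone sketch of the construction (the MagicMock stands in for a real requests.Response; a sketch, not the test itself):

    from unittest.mock import MagicMock

    from huggingface_hub.utils import HfHubHTTPError

    # The MagicMock satisfies any attribute lookups (headers, status_code)
    # the error class may perform on the attached response.
    err = HfHubHTTPError("failed to query model info", response=MagicMock())
    assert isinstance(err, Exception)
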
@@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
     "mm_encoder_attn_backend",
     [None] + current_platform.get_supported_vit_attn_backends(),
 )
+@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
 @create_new_process_for_each_test()
 def test_vit_backend_functionality(
     model_key: str,

@@ -642,48 +642,130 @@ _OPS_REGISTERED = False


 class rocm_aiter_ops:
+    """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.
+
+    This class centralizes the import and registration of AITER ops,
+    and provides a unified interface for checking if AITER is enabled.
+    Operations are only available on supported gfx9
+    architectures when aiter is installed.
+
+    The class uses environment variables to control which features are enabled,
+    allowing fine-grained control over which AITER optimizations are used.
+
+    Environment Variables:
+        VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
+        VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
+        VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
+        VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
+        VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
+        VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
+        VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
+        VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
+        VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
+        VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
+        VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
+        VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.
+
+    Note:
+        The environment variables are read when the module is imported, so
+        changing them afterwards has no effect. This is done out of performance
+        considerations: accessing environment variables is expensive, as
+        described in https://github.com/vllm-project/vllm/issues/17067, so we
+        don't want to do it repeatedly, especially in the hot path (the forward
+        pass). You can call refresh_env_variables() to reload the env variables
+        after monkey patching them in a unit test.
+
+    Check Functions:
+        All check functions (is_*_enabled) are decorated with @if_aiter_supported,
+        which verifies: (1) the platform is ROCm, (2) the device arch is gfx9, and
+        (3) the aiter library is installed. The check function then also verifies
+        that the corresponding environment variable is enabled, i.e.
+
+                                                               ___
+        is_enabled() == current_platform.is_rocm() and            | checked by
+                        current_platform.is_on_gfx9() and         | @if_aiter_supported
+                        IS_AITER_FOUND and _______________________|
+                        cls._AITER_ENABLED  -----> checked by the logic in `is_enabled()`
+
+    Example:
+        from vllm._aiter_ops import rocm_aiter_ops
+
+        # Check if aiter is enabled before using operations
+        if rocm_aiter_ops.is_enabled():
+            result = rocm_aiter_ops.rms_norm(x, weight, epsilon)
+
+    Operations:
+        - RMS normalization: rms_norm, rms_norm2d_with_add
+        - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
+        - Fused MoE: fused_moe, asm_moe_tkw1
+        - Routing: topk_softmax, biased_grouped_topk, grouped_topk
+        - MLA decode: mla_decode_fwd
+        - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
+        - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
+    """
+
+    # Check if the env variable is set
     _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
     _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
     _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
     _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
     _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
-    _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
     _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
     _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+    # TODO: Consolidate under _LINEAR_ENABLED
     _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
+    # TODO: Consolidate under _LINEAR_ENABLED
     _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
+    # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
     _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
     _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
+    # TODO: Consolidate under _LINEAR_ENABLED
     _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

+    @classmethod
+    def refresh_env_variables(cls):
+        """
+        Since the environment variables are read when the module is imported,
+        this helper reloads all of them from the current environment, for
+        example after monkey patching the env variables in a unit test.
+        """
+        cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
+        cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
+        cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
+        cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
+        cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
+        cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
+        cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+        cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
+        cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
+        cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
+        cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
+        cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
+
     @classmethod
     @if_aiter_supported
     def is_enabled(cls) -> bool:
-        """Verifies device specs and availability of aiter main env variable."""
         return cls._AITER_ENABLED

     @classmethod
     @if_aiter_supported
     def is_linear_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._LINEAR_ENABLED

     @classmethod
     @if_aiter_supported
     def is_linear_fp8_enaled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls.is_linear_enabled()

     @classmethod
     @if_aiter_supported
     def is_rmsnorm_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

     @classmethod
     @if_aiter_supported
     def is_fused_moe_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._FMOE_ENABLED

     @classmethod
@@ -694,25 +776,16 @@ class rocm_aiter_ops:
     @classmethod
     @if_aiter_supported
     def is_mla_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._MLA_ENABLED

     @classmethod
     @if_aiter_supported
     def is_mha_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._MHA_ENABLED

-    @classmethod
-    @if_aiter_supported
-    def is_pa_attn_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
-        return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
-
     @classmethod
     @if_aiter_supported
     def is_triton_unified_attn_enabled(cls) -> bool:
-        """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

     @classmethod

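The guard-plus-flag layout described in the docstring above can be illustrated with a small standalone sketch (hypothetical names, not vLLM's actual implementation): the decorator short-circuits on the platform checks, so the wrapped classmethod only has to consult its env-derived flag.

    from functools import wraps

    IS_AITER_FOUND = False  # stand-in for "import aiter succeeded"


    def if_aiter_supported(fn):
        """Short-circuit to False unless the platform-level checks pass."""

        @wraps(fn)
        def wrapper(cls, *args, **kwargs):
            # The real decorator also checks current_platform.is_rocm()
            # and current_platform.is_on_gfx9().
            if not IS_AITER_FOUND:
                return False
            return fn(cls, *args, **kwargs)

        return wrapper


    class toy_aiter_ops:
        _AITER_ENABLED = True  # snapshot of VLLM_ROCM_USE_AITER at import time

        @classmethod
        @if_aiter_supported
        def is_enabled(cls) -> bool:
            return cls._AITER_ENABLED


    print(toy_aiter_ops.is_enabled())  # False: the platform guard wins

A unit test that monkeypatches the environment would then call refresh_env_variables() to re-snapshot the cached flags, as the docstring notes.
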
@@ -937,7 +937,7 @@ class CompilationConfig:
             or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
         ):
             logger.warning_once(
-                "Using piecewise compilation with empty splitting_ops"
+                "Using piecewise cudagraph with empty splitting_ops"
             )
         if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
             logger.warning_once(

@@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE):
         self._shared_experts = shared_experts

         # Disable shared expert overlap if:
-        # - we are using eplb, because of correctness issues
+        # - we are using eplb with non-default backend, because of correctness issues
         # - we are using flashinfer with DP, since there is nothing to gain
         # - we are using marlin kernels
+        backend = self.moe_parallel_config.all2all_backend
         self.use_overlapped = (
             use_overlapped
             and not (
-                # TODO(wentao): find the root cause and remove this condition
-                self.enable_eplb
+                (self.enable_eplb and backend != "allgather_reducescatter")
                 or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             )
             and self._shared_experts is not None

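Condensed into a plain predicate (a hypothetical helper, for illustration only), the overlap gating above reads:

    def overlap_allowed(requested: bool, enable_eplb: bool, all2all_backend: str,
                        flashinfer_cutlass_with_dp: bool,
                        has_shared_experts: bool) -> bool:
        # Overlap is blocked for EPLB on non-default backends (correctness)
        # and for flashinfer-cutlass under data parallelism (no benefit).
        blocked = (
            (enable_eplb and all2all_backend != "allgather_reducescatter")
            or flashinfer_cutlass_with_dp
        )
        return requested and not blocked and has_shared_experts
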
@@ -469,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
             )
             logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

-            layer.gemm1_weights_fp4_shuffled = Parameter(
+            layer.w13_weight = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_weights_fp4_shuffled = Parameter(
-                gemm2_weights_fp4_shuffled, requires_grad=False
-            )
-            layer.gemm1_scales_fp4_shuffled = Parameter(
+            layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
+            layer.w13_weight_scale = Parameter(
                 gemm1_scales_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_scales_fp4_shuffled = Parameter(
+            layer.w2_weight_scale = Parameter(
                 gemm2_scales_fp4_shuffled, requires_grad=False
             )

@@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
                 (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                 requires_grad=False,
             )
-
-            # Clean up weights that won't be used by TRT-LLM
-            del layer.w2_weight
-            del layer.w2_weight_scale
-            del layer.w13_weight
-            del layer.w13_weight_scale
         else:
             # swizzle weight scales
             layer.w13_weight_scale = torch.nn.Parameter(

@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             )
             logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

-            layer.gemm1_weights_fp4_shuffled = Parameter(
+            layer.w13_weight = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_weights_fp4_shuffled = Parameter(
-                gemm2_weights_fp4_shuffled, requires_grad=False
-            )
-            layer.gemm1_scales_fp4_shuffled = Parameter(
+            layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
+            layer.w13_weight_scale = Parameter(
                 gemm1_scales_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_scales_fp4_shuffled = Parameter(
+            layer.w2_weight_scale = Parameter(
                 gemm2_scales_fp4_shuffled, requires_grad=False
             )

@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                 requires_grad=False,
             )
-
-            # Clean up weights that won't be used by TRT-LLM
-            del layer.w2_weight
-            del layer.w2_weight_scale
-            del layer.w13_weight
-            del layer.w13_weight_scale
         elif self.use_marlin:
             # Marlin processing
             prepare_moe_fp4_layer_for_marlin(layer)

@@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe(
         hidden_states_scale=hidden_states_scale_linear_fp4.view(
             torch.float8_e4m3fn
         ).flatten(),
-        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
-        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm1_weights=layer.w13_weight.data,
+        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
         gemm1_bias=None,
         gemm1_alpha=None,
         gemm1_beta=None,
         gemm1_clamp_limit=None,
-        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
-        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm2_weights=layer.w2_weight.data,
+        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
         gemm2_bias=None,
         output1_scale_scalar=layer.g1_scale_c.data,
         output1_scale_gate_scalar=layer.g1_alphas.data,
@@ -380,18 +376,14 @@ def flashinfer_trtllm_fp4_routed_moe(
         hidden_states_scale=hidden_states_scale_linear_fp4.view(
             torch.float8_e4m3fn
         ).flatten(),
-        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
-        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm1_weights=layer.w13_weight.data,
+        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
         gemm1_bias=None,
         gemm1_alpha=None,
         gemm1_beta=None,
         gemm1_clamp_limit=None,
-        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
-        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm2_weights=layer.w2_weight.data,
+        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
         gemm2_bias=None,
         output1_scale_scalar=layer.g1_scale_c.data,
         output1_scale_gate_scalar=layer.g1_alphas.data,

@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
             "position_ids",
             torch.arange(config.max_position_embeddings).unsqueeze(0),
         )
-        self.position_embedding_type = config.position_embedding_type
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type != "absolute":
             raise ValueError(
                 "Only 'absolute' position_embedding_type" + " is supported"

@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
             torch.arange(config.max_position_embeddings).unsqueeze(0),
         )

-        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError(
-                "Only 'absolute' position_embedding_type" + " is supported"
-            )
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     def _build_model(
         self, vllm_config: VllmConfig, prefix: str = ""
     ) -> BertModel | BertWithRope:
-        if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
-            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
+        hf_config = vllm_config.model_config.hf_config
+        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
+        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
+            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
         else:
-            return BertModel(
-                vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
-            )
+            return JinaRobertaModel(**kwargs)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         weights_list = list(weights)

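Both embedding classes now read position_embedding_type defensively, and _build_model dispatches on the same value: "absolute" selects the plain BertModel with RobertaEmbedding, anything else (e.g. "rotary") selects JinaRobertaModel. A quick sketch of the getattr fallback:

    from types import SimpleNamespace

    old_style_config = SimpleNamespace()  # predates position_embedding_type
    new_style_config = SimpleNamespace(position_embedding_type="rotary")

    for config in (old_style_config, new_style_config):
        print(getattr(config, "position_embedding_type", "absolute"))
    # absolute
    # rotary
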
@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
     ) -> tuple[torch.Tensor, ...]:
         batch_size, dim, num_channels = hidden_states.shape
@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
     ) -> tuple[torch.Tensor]:
-        self_outputs = self.self(
-            hidden_states, attention_mask, head_mask, output_attentions
-        )
+        self_outputs = self.self(hidden_states, attention_mask, output_attentions)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]
         return outputs
@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
         self,
         hidden_states: torch.Tensor,
         input_dimensions: tuple[int, int],
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
         always_partition: bool | None = False,
     ) -> tuple[torch.Tensor]:
         height, width = input_dimensions
         for i, layer_module in enumerate(self.blocks):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-
             layer_outputs = layer_module(
                 hidden_states,
                 input_dimensions,
-                layer_head_mask,
                 output_attentions,
                 always_partition,
             )
@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
         self,
         hidden_states: torch.Tensor,
         input_dimensions: tuple[int, int],
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
         always_partition: bool | None = False,
     ) -> tuple[torch.Tensor]:
         for i, layer_module in enumerate(self.layers):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-
             layer_outputs = layer_module(
                 hidden_states,
                 input_dimensions,
-                layer_head_mask,
                 output_attentions,
                 always_partition,
             )
@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = None,
     ) -> tuple[torch.Tensor]:
         embedding_output, input_dimensions = self.embeddings(pixel_values)
@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
         encoder_outputs = self.encoder(
             embedding_output,
             input_dimensions,
-            head_mask=head_mask,
             output_attentions=output_attentions,
         )

@@ -5,6 +5,7 @@
 """PyTorch Ultravox model."""

 import copy
+import inspect
 from collections.abc import Iterable, Mapping, Sequence
 from types import SimpleNamespace
 from typing import Annotated, Any, Literal, TypeAlias
@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
         )
         hidden_states = hidden_states + positions

+        # Backward compatibility for Transformers v4 where layer_head_mask
+        # was a required argument for WhisperEncoderLayer.forward
+        kwargs = {}
+        if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
+            kwargs["layer_head_mask"] = None
+
         for layer in self.layers:
             layer_outputs = layer(
                 hidden_states,
                 attention_mask=extended_attention_mask,
-                layer_head_mask=None,
+                **kwargs,
             )
             hidden_states = layer_outputs[0]
@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):

         attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

+        # Backward compatibility for Transformers v4 where layer_head_mask
+        # was a required argument for WhisperEncoderLayer.forward
+        kwargs = {}
+        if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
+            kwargs["layer_head_mask"] = None
+
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
                 attention_mask,
-                layer_head_mask=None,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

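The signature probe generalizes to any argument that one library version requires and a later one removed; a self-contained sketch of the mechanism:

    import inspect


    def v4_layer(hidden_states, attention_mask, layer_head_mask):  # old: required
        return (hidden_states,)


    def v5_layer(hidden_states, attention_mask):  # new: argument removed
        return (hidden_states,)


    for layer in (v4_layer, v5_layer):
        kwargs = {}
        if "layer_head_mask" in inspect.signature(layer).parameters:
            kwargs["layer_head_mask"] = None
        print(layer("h", None, **kwargs))  # both calls succeed
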
@@ -124,8 +124,6 @@ def use_rocm_custom_paged_attention(
     alibi_slopes: torch.Tensor | None = None,
     sinks: torch.Tensor | None = None,
 ) -> bool:
-    from vllm._aiter_ops import rocm_aiter_ops
-
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
     ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
         and (gqa_ratio >= 1 and gqa_ratio <= 16)
         and max_seq_len <= 128 * 1024
         and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
-        and not (rocm_aiter_ops.is_pa_attn_enabled())
         and sinks is None
     )

@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
     MLACommonImpl,
     MLACommonMetadata,
     MLACommonMetadataBuilder,
+    QueryLenSupport,
 )
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     qo_indptr: torch.Tensor | None = None
     # The dtype of MLA out tensor
     attn_out_dtype: torch.dtype = torch.bfloat16
+    # The max query length
+    max_qo_len: int | None = None


 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    _cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

     def __init__(
         self,
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             max_num_reqs, dtype=torch.int32, device=device
         )

-        self.qo_indptr = torch.arange(
-            0, max_num_reqs + 1, dtype=torch.int32, device=device
+        self.qo_indptr = torch.zeros(
+            max_num_reqs + 1, dtype=torch.int32, device=device
         )

     def _build_decode(
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )
+        qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        max_qo_len = qo_len.max().item()

         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             self.paged_kv_last_page_len[num_reqs:].fill_(1)
             paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

+            self.qo_indptr[: 1 + num_reqs].copy_(
+                query_start_loc_device, non_blocking=True
+            )
+            self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
             qo_indptr = self.qo_indptr[: 1 + num_reqs]

         else:
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            max_qo_len=max_qo_len,
             attn_out_dtype=self.decode_attn_out_dtype,
         )

@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

-        # max_seqlen_qo must be 1 except for MTP
-        # TODO: Find the best value for MTP
-        max_seqlen_qo = 1
         rocm_aiter_ops.mla_decode_fwd(
             q,
             kv_buffer,
             o,
             self.scale,
             attn_metadata.decode.qo_indptr,
-            max_seqlen_qo,
+            attn_metadata.decode.max_qo_len,
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,

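The new metadata replaces the hardcoded max_seqlen_qo = 1 with a value derived from the batch, so decodes carrying speculative/MTP draft tokens can pass more than one query token per request. The arithmetic in isolation (illustrative values):

    import torch

    # query_start_loc for 4 requests: two decode-only, two carrying a draft token
    query_start_loc_cpu = torch.tensor([0, 1, 2, 4, 6], dtype=torch.int32)

    qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    max_qo_len = qo_len.max().item()
    print(qo_len.tolist(), max_qo_len)  # [1, 1, 2, 2] 2
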
@@ -187,6 +187,12 @@ class Scheduler(SchedulerInterface):
             if self.is_encoder_decoder
             else EncoderCacheManager(cache_size=encoder_cache_size)
         )
+        # For encoder-decoder models, allocate the maximum number of tokens for Cross
+        # Attn blocks, as for Whisper its input is always padded to the maximum length.
+        # TODO (NickLucche): Generalize to models with variable-length encoder inputs.
+        self._num_encoder_max_input_tokens = (
+            MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
+        )

         speculative_config = vllm_config.speculative_config
         self.use_eagle = False
@@ -568,17 +574,11 @@ class Scheduler(SchedulerInterface):
             0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
         )

-        # Determine if we need to allocate cross-attention blocks.
-        if self.is_encoder_decoder and request.has_encoder_inputs:
-            # TODO(russellb): For Whisper, we know that the input is
-            # always padded to the maximum length. If we support other
-            # encoder-decoder models, this will need to be updated if we
-            # want to only allocate what is needed.
-            num_encoder_tokens = (
-                self.scheduler_config.max_num_encoder_input_tokens
-            )
-        else:
-            num_encoder_tokens = 0
+        num_encoder_tokens = (
+            self._num_encoder_max_input_tokens
+            if self.is_encoder_decoder and request.has_encoder_inputs
+            else 0
+        )

         new_blocks = self.kv_cache_manager.allocate_slots(
             request,

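The replaced branch reduces to a single allocation rule; sketched as a free function (hypothetical name):

    def num_cross_attn_tokens(is_encoder_decoder: bool, has_encoder_inputs: bool,
                              max_encoder_len: int) -> int:
        # Whisper-style models pad encoder input to the maximum length, so the
        # full allocation is made whenever cross-attention blocks are needed.
        return max_encoder_len if is_encoder_decoder and has_encoder_inputs else 0


    print(num_cross_attn_tokens(True, True, 1500))    # e.g. Whisper: 1500
    print(num_cross_attn_tokens(False, False, 1500))  # decoder-only: 0
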
@@ -21,8 +21,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 if TYPE_CHECKING:
     import outlines_core as oc
     import transformers.file_utils as file_utils
-    import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
     import xgrammar as xgr
+    from transformers.convert_slow_tokenizer import bytes_to_unicode

     from vllm.tokenizers import TokenizerLike
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -30,10 +30,8 @@ else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
     oc = LazyLoader("oc", globals(), "outlines_core")
     file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
-    tokenization_gpt2 = LazyLoader(
-        "tokenization_gpt2",
-        globals(),
-        "transformers.models.gpt2.tokenization_gpt2",
+    bytes_to_unicode = LazyLoader(
+        "bytes_to_unicode", globals(), "transformers.convert_slow_tokenizer"
     )

     TokenizerLike = object
@@ -204,7 +202,7 @@ def _reduced_vocabulary(
         A Dict of token string -> equivalent token ids
     """

-    unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()}
+    unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}

     def convert_token_to_string(token: str) -> str:
         string = tokenizer.convert_tokens_to_string([token])

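The swapped import points at the canonical GPT-2 byte-level table, previously reached through the gpt2 tokenization module. A short demonstration of the inversion the code performs, using only the public helper:

    from transformers.convert_slow_tokenizer import bytes_to_unicode

    # bytes_to_unicode() maps each byte 0..255 to a printable unicode character;
    # inverting it recovers raw bytes from token strings.
    unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
    print(len(unicode_to_bytes))  # 256
    print(unicode_to_bytes["Ġ"])  # 32, the space byte in GPT-2's alphabet
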
@@ -145,12 +145,20 @@ class WorkspaceManager:

         for ubatch_id in range(self._num_ubatches):
             current_workspace = self._current_workspaces[ubatch_id]
-            if current_workspace is None:
+            if (
+                current_workspace is None
+                or self._workspace_size_bytes(current_workspace) < required_bytes
+            ):
+                # Delete old tensor before allocating new one to avoid
+                # memory spike from resize_(). resize_() allocates new
+                # memory before freeing old, which can cause OOM.
+                # Must clear the list reference first since local var
+                # is just a copy of the reference.
+                self._current_workspaces[ubatch_id] = None
+                del current_workspace
                 self._current_workspaces[ubatch_id] = torch.empty(
                     (required_bytes,), dtype=torch.uint8, device=self._device
                 )
-            elif self._workspace_size_bytes(current_workspace) < required_bytes:
-                current_workspace.resize_(required_bytes)

         if envs.VLLM_DEBUG_WORKSPACE:
             logger.info(

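The comment block spells out the peak-memory reasoning; here is the same pattern in miniature (a standalone sketch, not the vLLM class). Dropping every reference to the old tensor before allocating its replacement keeps peak usage at max(old, new), whereas Tensor.resize_() briefly needs old + new while it reallocates, per the comment above.

    import torch

    workspaces: list = [None]  # mirrors self._current_workspaces


    def ensure_workspace(required_bytes: int) -> torch.Tensor:
        ws = workspaces[0]
        if ws is None or ws.numel() < required_bytes:
            workspaces[0] = None  # clear the list's reference first...
            del ws                # ...then the local copy, freeing the old tensor
            workspaces[0] = torch.empty((required_bytes,), dtype=torch.uint8)
        return workspaces[0]


    buf = ensure_workspace(1 << 20)
    assert ensure_workspace(1 << 10) is buf  # still large enough: reused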