Merge branch 'main' into mlm-full-lora-support

This commit is contained in:
Jee Jee Li 2025-12-17 00:33:42 +08:00 committed by GitHub
commit 94dce5c3d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 182 additions and 123 deletions

View File

@ -1223,6 +1223,8 @@ steps:
# FIXIT: find out which code initialize cuda before running the test # FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it # before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
# A lot of these tests are on the edge of OOMing
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# There is some Tensor Parallelism related processing logic in LoRA that # There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation. # requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_chatglm3_tp.py

View File

@ -3,7 +3,7 @@
from collections import OrderedDict from collections import OrderedDict
from typing import NamedTuple from typing import NamedTuple
from unittest.mock import patch from unittest.mock import MagicMock, patch
import pytest import pytest
from huggingface_hub.utils import HfHubHTTPError from huggingface_hub.utils import HfHubHTTPError
@ -194,5 +194,8 @@ def test_get_adapter_absolute_path_huggingface_error(
# Hugging Face model identifier with download error # Hugging Face model identifier with download error
path = "org/repo" path = "org/repo"
mock_exist.return_value = False mock_exist.return_value = False
mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info") mock_snapshot_download.side_effect = HfHubHTTPError(
"failed to query model info",
response=MagicMock(),
)
assert get_adapter_absolute_path(path) == path assert get_adapter_absolute_path(path) == path

View File

@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
"mm_encoder_attn_backend", "mm_encoder_attn_backend",
[None] + current_platform.get_supported_vit_attn_backends(), [None] + current_platform.get_supported_vit_attn_backends(),
) )
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_vit_backend_functionality( def test_vit_backend_functionality(
model_key: str, model_key: str,

View File

@ -642,48 +642,130 @@ _OPS_REGISTERED = False
class rocm_aiter_ops: class rocm_aiter_ops:
"""ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.
This class centralizes the import and registration of AITER ops,
and provides a unified interface for checking if AITER is enabled.
Operations are only available on supported gfx9
architectures when aiter is installed.
The class uses environment variables to control which features are enabled,
allowing fine-grained control over which AITER optimizations are used.
Environment Variables:
VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.
Note:
The environment variables are assigned when the module is imported,
so you can't change the environment variables after the module is imported.
This is done out of performance consideration. Accessing environment variables
is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067
so we don't want to do it repeatedly, especially in the hot path (the forward pass).
You can call the refresh_env_variables() function to reload the env variables
after monkey patching the env variables in the unit test.
Check Functions:
All check functions (is_*_enabled) are decorated with @if_aiter_supported,
which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
(3) aiter library is installed. The check function then also verifies
the corresponding environment variable is enabled.
i.e. ___
is_enabled() == current_platform.is_rocm() and | checked by
current_platform.is_on_gfx9() and | @if_aiter_supported
IS_AITER_FOUND and _______________|
cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`
Example:
from vllm._aiter_ops import rocm_aiter_ops
# Check if aiter is enabled before using operations
if rocm_aiter_ops.is_enabled():
result = rocm_aiter_ops.rms_norm(x, weight, epsilon)
Operations:
- RMS normalization: rms_norm, rms_norm2d_with_add
- GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
- Fused MoE: fused_moe, asm_moe_tkw1
- Routing: topk_softmax, biased_grouped_topk, grouped_topk
- MLA decode: mla_decode_fwd
- Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
- Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
"""
# Check if the env variable is set
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@classmethod
def refresh_env_variables(cls):
"""
Since the environment variables are assigned when the module is imported,
this is a helper function to reload all the env variables from
the environment.
For example, after monkey patching the env variables in a unit test,
you can call this function to reload the env variables.
"""
cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_enabled(cls) -> bool: def is_enabled(cls) -> bool:
"""Verifies device specs and availability of aiter main env variable."""
return cls._AITER_ENABLED return cls._AITER_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_linear_enabled(cls) -> bool: def is_linear_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._LINEAR_ENABLED return cls._AITER_ENABLED and cls._LINEAR_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool: def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() return cls.is_linear_enabled()
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool: def is_rmsnorm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_fused_moe_enabled(cls) -> bool: def is_fused_moe_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FMOE_ENABLED return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod @classmethod
@ -694,25 +776,16 @@ class rocm_aiter_ops:
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_mla_enabled(cls) -> bool: def is_mla_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MLA_ENABLED return cls._AITER_ENABLED and cls._MLA_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_mha_enabled(cls) -> bool: def is_mha_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MHA_ENABLED return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool: def is_triton_unified_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
@classmethod @classmethod

View File

@ -937,7 +937,7 @@ class CompilationConfig:
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
): ):
logger.warning_once( logger.warning_once(
"Using piecewise compilation with empty splitting_ops" "Using piecewise cudagraph with empty splitting_ops"
) )
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once( logger.warning_once(

View File

@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE):
self._shared_experts = shared_experts self._shared_experts = shared_experts
# Disable shared expert overlap if: # Disable shared expert overlap if:
# - we are using eplb, because of correctness issues # - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there is nothing to gain # - we are using flashinfer with DP, since there is nothing to gain
# - we are using marlin kernels # - we are using marlin kernels
backend = self.moe_parallel_config.all2all_backend
self.use_overlapped = ( self.use_overlapped = (
use_overlapped use_overlapped
and not ( and not (
# TODO(wentao): find the root cause and remove this condition (self.enable_eplb and backend != "allgather_reducescatter")
self.enable_eplb
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
) )
and self._shared_experts is not None and self._shared_experts is not None

View File

@ -469,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
logger.debug_once("Finished shuffling weights for TRT-LLM MOE") logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter( layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False gemm1_weights_fp4_shuffled, requires_grad=False
) )
layer.gemm2_weights_fp4_shuffled = Parameter( layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
gemm2_weights_fp4_shuffled, requires_grad=False layer.w13_weight_scale = Parameter(
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False gemm1_scales_fp4_shuffled, requires_grad=False
) )
layer.gemm2_scales_fp4_shuffled = Parameter( layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False gemm2_scales_fp4_shuffled, requires_grad=False
) )
@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False, requires_grad=False,
) )
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else: else:
# swizzle weight scales # swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter( layer.w13_weight_scale = torch.nn.Parameter(

View File

@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
logger.debug_once("Finished shuffling weights for TRT-LLM MOE") logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter( layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False gemm1_weights_fp4_shuffled, requires_grad=False
) )
layer.gemm2_weights_fp4_shuffled = Parameter( layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
gemm2_weights_fp4_shuffled, requires_grad=False layer.w13_weight_scale = Parameter(
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False gemm1_scales_fp4_shuffled, requires_grad=False
) )
layer.gemm2_scales_fp4_shuffled = Parameter( layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False gemm2_scales_fp4_shuffled, requires_grad=False
) )
@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False, requires_grad=False,
) )
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
elif self.use_marlin: elif self.use_marlin:
# Marlin processing # Marlin processing
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer)

View File

@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view( hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn torch.float8_e4m3fn
).flatten(), ).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm1_bias=None, gemm1_bias=None,
gemm1_alpha=None, gemm1_alpha=None,
gemm1_beta=None, gemm1_beta=None,
gemm1_clamp_limit=None, gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm2_bias=None, gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data, output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data, output1_scale_gate_scalar=layer.g1_alphas.data,
@ -380,18 +376,14 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view( hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn torch.float8_e4m3fn
).flatten(), ).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm1_bias=None, gemm1_bias=None,
gemm1_alpha=None, gemm1_alpha=None,
gemm1_beta=None, gemm1_beta=None,
gemm1_clamp_limit=None, gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm2_bias=None, gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data, output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data, output1_scale_gate_scalar=layer.g1_alphas.data,

View File

@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
"position_ids", "position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0), torch.arange(config.max_position_embeddings).unsqueeze(0),
) )
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type != "absolute": if self.position_embedding_type != "absolute":
raise ValueError( raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported" "Only 'absolute' position_embedding_type" + " is supported"

View File

@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
torch.arange(config.max_position_embeddings).unsqueeze(0), torch.arange(config.max_position_embeddings).unsqueeze(0),
) )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model( def _build_model(
self, vllm_config: VllmConfig, prefix: str = "" self, vllm_config: VllmConfig, prefix: str = ""
) -> BertModel | BertWithRope: ) -> BertModel | BertWithRope:
if vllm_config.model_config.hf_config.position_embedding_type == "rotary": hf_config = vllm_config.model_config.hf_config
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) kwargs = dict(vllm_config=vllm_config, prefix=prefix)
if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
return BertModel(**kwargs, embedding_class=RobertaEmbedding)
else: else:
return BertModel( return JinaRobertaModel(**kwargs)
vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights) weights_list = list(weights)

View File

@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None, attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None, attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(hidden_states, attention_mask, output_attentions)
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] outputs = (attention_output,) + self_outputs[1:]
return outputs return outputs
@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_dimensions: tuple[int, int], input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
always_partition: bool | None = False, always_partition: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
height, width = input_dimensions height, width = input_dimensions
for i, layer_module in enumerate(self.blocks): for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
input_dimensions, input_dimensions,
layer_head_mask,
output_attentions, output_attentions,
always_partition, always_partition,
) )
@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_dimensions: tuple[int, int], input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
always_partition: bool | None = False, always_partition: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers): for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
input_dimensions, input_dimensions,
layer_head_mask,
output_attentions, output_attentions,
always_partition, always_partition,
) )
@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor | None = None, pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None, output_attentions: bool | None = None,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values) embedding_output, input_dimensions = self.embeddings(pixel_values)
@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
input_dimensions, input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )

View File

@ -5,6 +5,7 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import copy import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
) )
hidden_states = hidden_states + positions hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers: for layer in self.layers:
layer_outputs = layer( layer_outputs = layer(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
layer_head_mask=None, **kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=None, **kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@ -124,8 +124,6 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None, sinks: torch.Tensor | None = None,
) -> bool: ) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16) and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None and sinks is None
) )

View File

@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr: torch.Tensor | None = None qo_indptr: torch.Tensor | None = None
# The dtype of MLA out tensor # The dtype of MLA out tensor
attn_out_dtype: torch.dtype = torch.bfloat16 attn_out_dtype: torch.dtype = torch.bfloat16
# The max query output length: int
max_qo_len: int | None = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of: # TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945 # https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support: ClassVar[AttentionCGSupport] = ( _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
)
def __init__( def __init__(
self, self,
@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs, dtype=torch.int32, device=device max_num_reqs, dtype=torch.int32, device=device
) )
self.qo_indptr = torch.arange( self.qo_indptr = torch.zeros(
0, max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
) )
def _build_decode( def _build_decode(
@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device.cumsum(dim=0, dtype=torch.int32), seq_lens_device.cumsum(dim=0, dtype=torch.int32),
] ]
) )
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_qo_len = qo_len.max().item()
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_last_page_len[num_reqs:].fill_(1) self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
self.qo_indptr[: 1 + num_reqs].copy_(
query_start_loc_device, non_blocking=True
)
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
qo_indptr = self.qo_indptr[: 1 + num_reqs] qo_indptr = self.qo_indptr[: 1 + num_reqs]
else: else:
@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr, qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device, dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype, attn_out_dtype=self.decode_attn_out_dtype,
) )
@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
rocm_aiter_ops.mla_decode_fwd( rocm_aiter_ops.mla_decode_fwd(
q, q,
kv_buffer, kv_buffer,
o, o,
self.scale, self.scale,
attn_metadata.decode.qo_indptr, attn_metadata.decode.qo_indptr,
max_seqlen_qo, attn_metadata.decode.max_qo_len,
attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len, attn_metadata.decode.paged_kv_last_page_len,

View File

@ -187,6 +187,12 @@ class Scheduler(SchedulerInterface):
if self.is_encoder_decoder if self.is_encoder_decoder
else EncoderCacheManager(cache_size=encoder_cache_size) else EncoderCacheManager(cache_size=encoder_cache_size)
) )
# For encoder-decoder models, allocate the maximum number of tokens for Cross
# Attn blocks, as for Whisper its input is always padded to the maximum length.
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
self._num_encoder_max_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
)
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
self.use_eagle = False self.use_eagle = False
@ -568,17 +574,11 @@ class Scheduler(SchedulerInterface):
0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
) )
# Determine if we need to allocate cross-attention blocks. num_encoder_tokens = (
if self.is_encoder_decoder and request.has_encoder_inputs: self._num_encoder_max_input_tokens
# TODO(russellb): For Whisper, we know that the input is if self.is_encoder_decoder and request.has_encoder_inputs
# always padded to the maximum length. If we support other else 0
# encoder-decoder models, this will need to be updated if we )
# want to only allocate what is needed.
num_encoder_tokens = (
self.scheduler_config.max_num_encoder_input_tokens
)
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,

View File

@ -21,8 +21,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
import outlines_core as oc import outlines_core as oc
import transformers.file_utils as file_utils import transformers.file_utils as file_utils
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
import xgrammar as xgr import xgrammar as xgr
from transformers.convert_slow_tokenizer import bytes_to_unicode
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
@ -30,10 +30,8 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
oc = LazyLoader("oc", globals(), "outlines_core") oc = LazyLoader("oc", globals(), "outlines_core")
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
tokenization_gpt2 = LazyLoader( bytes_to_unicode = LazyLoader(
"tokenization_gpt2", "bytes_to_unicode", globals(), "transformers.convert_slow_tokenizer"
globals(),
"transformers.models.gpt2.tokenization_gpt2",
) )
TokenizerLike = object TokenizerLike = object
@ -204,7 +202,7 @@ def _reduced_vocabulary(
A Dict of token string -> equivalent token ids A Dict of token string -> equivalent token ids
""" """
unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
def convert_token_to_string(token: str) -> str: def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token]) string = tokenizer.convert_tokens_to_string([token])

View File

@ -145,12 +145,20 @@ class WorkspaceManager:
for ubatch_id in range(self._num_ubatches): for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id] current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None: if (
current_workspace is None
or self._workspace_size_bytes(current_workspace) < required_bytes
):
# Delete old tensor before allocating new one to avoid
# memory spike from resize_(). resize_() allocates new
# memory before freeing old, which can cause OOM.
# Must clear the list reference first since local var
# is just a copy of the reference.
self._current_workspaces[ubatch_id] = None
del current_workspace
self._current_workspaces[ubatch_id] = torch.empty( self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device (required_bytes,), dtype=torch.uint8, device=self._device
) )
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)
if envs.VLLM_DEBUG_WORKSPACE: if envs.VLLM_DEBUG_WORKSPACE:
logger.info( logger.info(