diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c47202099ac60..34f5fedcf36e8 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -884,9 +884,8 @@ def _sdpa_attention(
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
-        with torch.backends.cuda.sdp_kernel(enable_math=True,
-                                            enable_flash=False,
-                                            enable_mem_efficient=False):
+        with torch.nn.attention.sdpa_kernel(
+                torch.nn.attention.SDPBackend.MATH):
             sub_out = torch.nn.functional.scaled_dot_product_attention(
                 query[:, start:end, :],
                 key[:, start:end, :],
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 22a45b60ca399..8164d919ca8b4 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path, map_location=device)
+            tensors = torch.load(lora_bin_file_path,
+                                 map_location=device,
+                                 weights_only=True)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")
 
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index a2b4949496897..0ea296b2f93d1 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -63,8 +63,8 @@ def _cast_if_autocast_enabled(*args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(
-            args, torch.get_autocast_gpu_dtype())
+        return torch.amp.autocast_mode._cast(
+            args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype())
 
 
 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -89,7 +89,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
             residual = x
         args = _cast_if_autocast_enabled(x, self.normalized_shape,
                                          self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast("cuda", enabled=False):
             x = torch.nn.functional.layer_norm(*args)
             return x if residual is None else (x, residual)
 
diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py
index ca00207a9b6f7..9f08a1c4c6f5a 100644
--- a/vllm/model_executor/models/phi4mm_utils.py
+++ b/vllm/model_executor/models/phi4mm_utils.py
@@ -1766,9 +1766,12 @@ class MultiHeadedAttention(nn.Module):
             if mask.dtype != q.dtype:
                 attn_mask = attn_mask.to(q.dtype)
 
-        with torch.backends.cuda.sdp_kernel(enable_flash=True,
-                                            enable_math=True,
-                                            enable_mem_efficient=True):
+        with torch.nn.attention.sdpa_kernel([
+                torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                torch.nn.attention.SDPBackend.MATH,
+                torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        ]):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q,
                 k,
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 255fac30bd78a..0c5a84c6508a1 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
         return self.load_bytes(base64.b64decode(data))
 
     def load_file(self, filepath: Path) -> torch.Tensor:
-        return torch.load(filepath)
+        return torch.load(filepath, weights_only=True)
 
     def encode_base64(self, media: torch.Tensor) -> str:
         return base64.b64encode(media.numpy()).decode('utf-8')
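
Not part of the patch: a minimal, self-contained sketch of the replacement PyTorch APIs the hunks above migrate to, assuming PyTorch >= 2.3. The tensor shapes and the "example.pt" path below are illustrative only.

# Illustrative sketch (not part of the patch), assuming PyTorch >= 2.3.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)  # arbitrary (batch, heads, seq, head_dim)

# torch.nn.attention.sdpa_kernel replaces torch.backends.cuda.sdp_kernel;
# it takes the allowed backends explicitly instead of enable_* flags.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)

# torch.amp.autocast takes the device type as an argument and replaces
# the deprecated torch.cuda.amp.autocast context manager.
with torch.amp.autocast("cuda", enabled=False):
    out = F.layer_norm(out, out.shape[-1:])

# weights_only=True restricts torch.load to plain tensors and containers,
# avoiding arbitrary pickle execution. "example.pt" is a hypothetical path.
torch.save(out, "example.pt")
loaded = torch.load("example.pt", weights_only=True)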