Apply torchfix (#15532)

Signed-off-by: cyy <cyyever@outlook.com>
Authored by cyyever on 2025-03-26 20:09:06 +08:00; committed by GitHub
parent cf5c8f1686
commit 1aa162e030
5 changed files with 15 additions and 11 deletions


@@ -884,9 +884,8 @@ def _sdpa_attention(
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
-        with torch.backends.cuda.sdp_kernel(enable_math=True,
-                                            enable_flash=False,
-                                            enable_mem_efficient=False):
+        with torch.nn.attention.sdpa_kernel(
+                torch.nn.attention.SDPBackend.MATH):
             sub_out = torch.nn.functional.scaled_dot_product_attention(
                 query[:, start:end, :],
                 key[:, start:end, :],

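The hunk above replaces the deprecated `torch.backends.cuda.sdp_kernel` context manager with `torch.nn.attention.sdpa_kernel`, which selects SDPA backends by enum instead of boolean flags. A minimal standalone sketch of the new call; tensor shapes and names are placeholders, not taken from the diff:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Placeholder (batch, heads, seq, head_dim) tensors, just to exercise the API.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Restrict scaled_dot_product_attention to the math backend -- the enum
# equivalent of enable_math=True with flash and mem-efficient disabled.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```

Backends not passed to `sdpa_kernel` are disabled for the duration of the context.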

@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path, map_location=device)
+            tensors = torch.load(lora_bin_file_path,
+                                 map_location=device,
+                                 weights_only=True)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")

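Passing `weights_only=True` makes `torch.load` use a restricted unpickler that only reconstructs tensors and primitive containers, which guards against arbitrary code execution from untrusted checkpoint files. A hedged sketch of the pattern; the file path and key names are hypothetical, not from the diff:

```python
import torch

# Save a plain tensor dict, then reload it with the restricted unpickler.
state = {"lora_A": torch.zeros(8, 4), "lora_B": torch.zeros(4, 8)}
torch.save(state, "adapter_model.bin")  # hypothetical path

tensors = torch.load("adapter_model.bin",
                     map_location="cpu",
                     weights_only=True)
```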

@@ -63,8 +63,8 @@ def _cast_if_autocast_enabled(*args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(
-            args, torch.get_autocast_gpu_dtype())
+        return torch.amp.autocast_mode._cast(
+            args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype())


 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -89,7 +89,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
             residual = x
         args = _cast_if_autocast_enabled(x, self.normalized_shape,
                                          self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast("cuda", enabled=False):
             x = torch.nn.functional.layer_norm(*args)
         return x if residual is None else (x, residual)

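Both Nemotron hunks swap the CUDA-specific `torch.cuda.amp` entry points for the device-generic `torch.amp` ones, with the device type passed explicitly. A small sketch of the replacement context manager, assuming a CUDA device is available; shapes are illustrative only:

```python
import torch

x = torch.randn(2, 8, device="cuda")
weight = torch.randn(8, 8, device="cuda")

# Run a matmul under autocast, then drop out of autocast for an op that
# should stay in full precision -- the same pattern as the layer_norm call
# in the hunk above.
with torch.amp.autocast("cuda", dtype=torch.float16):
    y = torch.nn.functional.linear(x, weight)
    with torch.amp.autocast("cuda", enabled=False):
        y32 = torch.nn.functional.layer_norm(x.float(), (8,))
```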

@@ -1766,9 +1766,12 @@ class MultiHeadedAttention(nn.Module):
             if mask.dtype != q.dtype:
                 attn_mask = attn_mask.to(q.dtype)

-            with torch.backends.cuda.sdp_kernel(enable_flash=True,
-                                                enable_math=True,
-                                                enable_mem_efficient=True):
+            with torch.nn.attention.sdpa_kernel([
+                    torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                    torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                    torch.nn.attention.SDPBackend.MATH,
+                    torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            ]):
                 x = torch.nn.functional.scaled_dot_product_attention(
                     q,
                     k,

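`torch.nn.attention.sdpa_kernel` also accepts a list of backends; within the context, any of the listed backends may be used and the runtime picks one that supports the given inputs, falling back to MATH if the fused kernels reject them. A sketch of that form, assuming a CUDA device and fp16 inputs so the fused backends are eligible; shapes are placeholders:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 32, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 32, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 32, 64, device="cuda", dtype=torch.float16)

# Enable every backend; selection happens at dispatch time based on the
# inputs and the kernels available in the current build.
with sdpa_kernel([
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
        SDPBackend.CUDNN_ATTENTION,
]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```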

@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
         return self.load_bytes(base64.b64decode(data))

     def load_file(self, filepath: Path) -> torch.Tensor:
-        return torch.load(filepath)
+        return torch.load(filepath, weights_only=True)

     def encode_base64(self, media: torch.Tensor) -> str:
         return base64.b64encode(media.numpy()).decode('utf-8')
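The same `weights_only=True` hardening applies when embeddings arrive as bytes rather than files. A hypothetical round trip through base64, loosely mirroring the media-IO methods above; it serializes with `torch.save` rather than raw numpy bytes, so it illustrates the restricted loader, not `encode_base64` itself:

```python
import base64
import io

import torch

emb = torch.randn(16, 32)

# Serialize the tensor, base64-encode the bytes, then reload with the
# restricted unpickler.
buf = io.BytesIO()
torch.save(emb, buf)
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")

decoded = torch.load(io.BytesIO(base64.b64decode(encoded)),
                     weights_only=True)
assert torch.equal(emb, decoded)
```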