Apply torchfix (#15532)

Signed-off-by: cyy <cyyever@outlook.com>
Author: cyyever
Date: 2025-03-26 20:09:06 +08:00
Committed by: GitHub
Parent: cf5c8f1686
Commit: 1aa162e030
5 changed files with 15 additions and 11 deletions


@@ -884,9 +884,8 @@ def _sdpa_attention(
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
-        with torch.backends.cuda.sdp_kernel(enable_math=True,
-                                            enable_flash=False,
-                                            enable_mem_efficient=False):
+        with torch.nn.attention.sdpa_kernel(
+                torch.nn.attention.SDPBackend.MATH):
             sub_out = torch.nn.functional.scaled_dot_product_attention(
                 query[:, start:end, :],
                 key[:, start:end, :],


@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path, map_location=device)
+            tensors = torch.load(lora_bin_file_path,
+                                 map_location=device,
+                                 weights_only=True)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")


@@ -63,8 +63,8 @@ def _cast_if_autocast_enabled(*args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(
-            args, torch.get_autocast_gpu_dtype())
+        return torch.amp.autocast_mode._cast(
+            args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype())


 class NemotronLayerNorm1P(nn.LayerNorm):

@@ -89,7 +89,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
             residual = x
         args = _cast_if_autocast_enabled(x, self.normalized_shape,
                                          self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast("cuda", enabled=False):
             x = torch.nn.functional.layer_norm(*args)
         return x if residual is None else (x, residual)
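Both autocast call sites move from the CUDA-specific torch.cuda.amp namespace to the unified torch.amp API, which takes the device type as an explicit argument. A small sketch of the public context-manager form only (the private _cast helper is left out; the layer-norm call is illustrative):

import torch

x = torch.randn(4, 1024)

# Old: with torch.cuda.amp.autocast(enabled=False):
# New: the device type is passed explicitly to the unified torch.amp API.
with torch.amp.autocast("cuda", enabled=False):
    y = torch.nn.functional.layer_norm(x, x.shape[-1:])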


@@ -1766,9 +1766,12 @@ class MultiHeadedAttention(nn.Module):
         if mask.dtype != q.dtype:
             attn_mask = attn_mask.to(q.dtype)
-        with torch.backends.cuda.sdp_kernel(enable_flash=True,
-                                            enable_math=True,
-                                            enable_mem_efficient=True):
+        with torch.nn.attention.sdpa_kernel([
+                torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                torch.nn.attention.SDPBackend.MATH,
+                torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        ]):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q,
                 k,
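Here the old boolean enable_* flags become an explicit allow-list of backends passed to sdpa_kernel. A standalone sketch of the list form (SDPBackend.CUDNN_ATTENTION assumes a recent PyTorch release; shapes are illustrative):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 4, 64, 32)

# enable_flash / enable_math / enable_mem_efficient=True turn into an allow-list;
# PyTorch picks the fastest available backend from the list at runtime.
with sdpa_kernel([
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
        SDPBackend.CUDNN_ATTENTION,
]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)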


@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
         return self.load_bytes(base64.b64decode(data))

     def load_file(self, filepath: Path) -> torch.Tensor:
-        return torch.load(filepath)
+        return torch.load(filepath, weights_only=True)

     def encode_base64(self, media: torch.Tensor) -> str:
         return base64.b64encode(media.numpy()).decode('utf-8')
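This is the same weights_only=True hardening as the LoRA loader above, applied to embedding files. A hedged sketch of the behavioral difference it buys; the file names and the payload class are made up:

import pickle
import torch

class NotATensor:  # stand-in for an arbitrary pickled object in a hostile file
    pass

torch.save(torch.randn(3, 512), "embedding.pt")
torch.save(NotATensor(), "payload.pt")

torch.load("embedding.pt", weights_only=True)    # fine: plain tensor data
try:
    torch.load("payload.pt", weights_only=True)  # the restricted unpickler rejects it
except pickle.UnpicklingError as exc:
    print("blocked:", exc)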