Pull Request №445 RosettaCommons/RFdiffusion/main ← haoyu-haoyu/RFdiffusion/fix/replace-deprecated-torch-cuda-amp
Merge: 9535f1938203a24937d7dadf0cb831d02cb5fc0e←a7cc837a05b7944bf06c22795d390adf8821b681
fix: replace deprecated torch.cuda.amp with torch.amp
----------------
Merge commit message:
fix: replace deprecated torch.cuda.amp with torch.amp
`torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` are deprecated as
of PyTorch 2.4 and will be removed in a future release. Replace them with
the device-explicit `torch.amp.autocast('cuda', ...)` and
`torch.amp.GradScaler('cuda', ...)` equivalents.
Files changed:
- rfdiffusion/Track_module.py (decorator on Str2Str.forward)
- env/SE3Transformer/se3_transformer/runtime/inference.py
- env/SE3Transformer/se3_transformer/runtime/training.py (2 instances)
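Because `autocast` also works as a decorator, the `Str2Str.forward` change amounts to swapping the decorator expression. The sketch below uses a hypothetical `Block` module rather than RFdiffusion's actual code, and the `enabled=False` argument is an illustrative assumption: disabling autocast for a numerically sensitive `forward` is a common use of the decorator form.

```python
import torch
import torch.nn as nn


class Block(nn.Module):  # hypothetical stand-in for a module like Str2Str
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    # Before (deprecated): @torch.cuda.amp.autocast(enabled=False)
    # enabled=False keeps this forward pass out of CUDA autocast, even when it
    # is called from inside an enclosing torch.amp.autocast("cuda") region.
    @torch.amp.autocast("cuda", enabled=False)
    def forward(self, x):
        return self.lin(x)


out = Block()(torch.randn(2, 8))
```

The decorator and context-manager forms take the same arguments, so each call site in the files listed above can be migrated the same way.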