Pull Request №454 RosettaCommons/RFdiffusion/main ← mooreneural/RFdiffusion/main
Merge: 2d0c003df46b9db41d119321f15403dec3716cd9←63f0e71f4edd2d79c46f04b603fe6e628680418c
perf/accuracy: Flash Attention, torch-native SO(3), cosine schedule, DDIM, analytical g(t), acos fix
----------------
Merge commit message:
perf/accuracy: Flash Attention, torch-native SO3, cosine schedule, DDIM, analytical g(t)
Attention (Attention_module.py):
- Replace hand-rolled einsum attention with F.scaled_dot_product_attention in
Attention, AttentionWithBias, and MSAColAttention. PyTorch dispatches to a fused
kernel such as Flash Attention when one is available on CUDA (20-40% speedup;
attention memory grows linearly rather than quadratically with sequence length).
- AttentionWithBias passes the pairwise bias as attn_mask so it is folded into
the fused kernel rather than materializing a separate attention matrix.
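A minimal sketch of the equivalence the attention change relies on, assuming PyTorch 2.0 or later. The shapes and the helper names einsum_attention and fused_attention are illustrative and are not taken from Attention_module.py. It checks that a float attn_mask passed to F.scaled_dot_product_attention behaves as an additive bias on the scaled logits:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch 2, 4 heads, sequence length 16, head dim 8.
B, H, L, D = 2, 4, 16, 8
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
bias = torch.randn(B, H, L, L)  # additive pairwise bias per (query, key)


def einsum_attention(q, k, v, bias):
    """Reference path: softmax(q k^T / sqrt(D) + bias) v with an explicit L x L matrix."""
    logits = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(q.shape[-1])
    attn = torch.softmax(logits + bias, dim=-1)
    return torch.einsum("bhij,bhjd->bhid", attn, v)


def fused_attention(q, k, v, bias):
    """Same computation through SDPA; a float attn_mask is added to the scaled logits."""
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)


ref = einsum_attention(q, k, v, bias)
out = fused_attention(q, k, v, bias)
max_abs_diff = (ref - out).abs().max().item()
```

Which backend actually runs depends on the PyTorch version, device, dtype, and whether a mask is supplied, so the speedup should be measured on the target hardware.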
SO3 diffusion (igso3.py, diffusion.py, inference/utils.py):
- Add hat_batch(), Log_torch(), Exp_torch() -- on-device rotation ops using
the Rodrigues formula. Eliminates all scipy CPU round-trips during inference.
- Replace scipy_R calls in reverse_sample_vectorized() and diffuse_frames() with
the new torch-native equivalents (stay on GPU, no .cpu()/.numpy() transfers).
- Remove redundant scipy rotation normalization in get_next_frames(); rotation
matrices from rigid_from_3_points are already orthogonal.
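The Rodrigues-based exponential and logarithm maps can be sketched as follows. This is a NumPy illustration so that it runs without a GPU. The names hat_batch_np, exp_so3_np, and log_so3_np are hypothetical stand-ins, not the hat_batch(), Exp_torch(), and Log_torch() added in this PR, whose signatures are not shown here. The log map below is only valid for rotation angles below pi:

```python
import numpy as np


def hat_batch_np(w):
    """Map rotation vectors (..., 3) to skew-symmetric matrices (..., 3, 3)."""
    w = np.asarray(w, dtype=float)
    K = np.zeros(w.shape[:-1] + (3, 3))
    K[..., 0, 1], K[..., 0, 2] = -w[..., 2], w[..., 1]
    K[..., 1, 0], K[..., 1, 2] = w[..., 2], -w[..., 0]
    K[..., 2, 0], K[..., 2, 1] = -w[..., 1], w[..., 0]
    return K


def exp_so3_np(w, eps=1e-8):
    """Rodrigues formula: R = I + sin(th)/th * K + (1 - cos(th))/th^2 * K^2."""
    w = np.asarray(w, dtype=float)
    theta = np.asarray(np.linalg.norm(w, axis=-1))[..., None, None]
    K = hat_batch_np(w)
    safe = np.maximum(theta, eps)
    # Use the small-angle limits (1 and 1/2) when theta is numerically zero.
    a = np.where(theta < eps, 1.0, np.sin(safe) / safe)
    b = np.where(theta < eps, 0.5, (1.0 - np.cos(safe)) / safe**2)
    return np.eye(3) + a * K + b * (K @ K)


def log_so3_np(R, eps=1e-8):
    """Rotation matrices (..., 3, 3) to rotation vectors (..., 3), angles in [0, pi)."""
    R = np.asarray(R, dtype=float)
    cos_theta = np.clip((np.trace(R, axis1=-2, axis2=-1) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.asarray(np.arccos(cos_theta))
    # The antisymmetric part R - R^T encodes 2 * sin(theta) * axis.
    vee = np.stack(
        [
            R[..., 2, 1] - R[..., 1, 2],
            R[..., 0, 2] - R[..., 2, 0],
            R[..., 1, 0] - R[..., 0, 1],
        ],
        axis=-1,
    )
    safe = np.maximum(theta, eps)
    scale = np.where(theta < eps, 0.5, safe / (2.0 * np.sin(safe)))
    return scale[..., None] * vee


# Round trip on a small batch of rotation vectors with angles in (0.1, 3.0).
rng = np.random.default_rng(0)
axes = rng.normal(size=(5, 3))
axes /= np.linalg.norm(axes, axis=-1, keepdims=True)
w = axes * rng.uniform(0.1, 3.0, size=(5, 1))
R = exp_so3_np(w)
w_back = log_so3_np(R)
```

A torch version would follow the same structure with torch.where, torch.clamp, and torch.linalg.norm, which is what keeps the tensors on the device.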
Noise schedule (diffusion.py):
- Add cosine schedule (Nichol & Dhariwal, 2021). Enabled via
schedule_type="cosine"; b0/bT are ignored for this mode.
- Analytical g(t) for linear schedule: eliminates a per-step autograd call.
Since g(t) = sqrt(d/dt sigma(t)^2) and d/dt sigma(t) = min_b + t*(max_b - min_b),
g(t) = sqrt(2 * sigma(t) * (min_b + t*(max_b - min_b))).
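Both schedule changes can be checked in a few lines of plain Python. The sketch below makes two assumptions that are not stated in this message: the cosine schedule uses the offset s = 0.008 and beta clip 0.999 from Nichol & Dhariwal, and the linear sigma(t) has the form min_sigma + t*min_b + 0.5*t^2*(max_b - min_b). All function names and parameter values are illustrative. A central finite difference stands in for the autograd call being replaced:

```python
import math


def cosine_alpha_bar(t, T, s=0.008):
    """Cumulative signal level alpha_bar(t) of the cosine schedule, normalized so alpha_bar(0) = 1."""

    def f(u):
        return math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2

    return f(t) / f(0)


def cosine_betas(T, s=0.008, max_beta=0.999):
    """Per-step betas: beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped at max_beta."""
    return [
        min(1 - cosine_alpha_bar(t, T, s) / cosine_alpha_bar(t - 1, T, s), max_beta)
        for t in range(1, T + 1)
    ]


def sigma_linear(t, min_sigma, min_b, max_b):
    """Assumed form of the linear-schedule sigma(t)."""
    return min_sigma + t * min_b + 0.5 * t**2 * (max_b - min_b)


def g_analytical(t, min_sigma, min_b, max_b):
    """Closed form of sqrt(d/dt sigma(t)^2) for the sigma above."""
    sigma = sigma_linear(t, min_sigma, min_b, max_b)
    return math.sqrt(2 * sigma * (min_b + t * (max_b - min_b)))


def g_finite_difference(t, min_sigma, min_b, max_b, h=1e-6):
    """Numerical stand-in for the autograd path: central difference of sigma(t)^2."""
    hi = sigma_linear(t + h, min_sigma, min_b, max_b) ** 2
    lo = sigma_linear(t - h, min_sigma, min_b, max_b) ** 2
    return math.sqrt((hi - lo) / (2 * h))


# Illustrative parameter values only.
params = dict(min_sigma=0.02, min_b=1.5, max_b=2.5)
betas = cosine_betas(T=50)
g_closed = g_analytical(0.5, **params)
g_numeric = g_finite_difference(0.5, **params)
```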
IGSO3 cache (diffusion.py):
- Add module-level _igso3_cache dict. Avoids repeated disk deserialization when
multiple Diffuser objects are created in the same process (batch inference).
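The caching pattern is a plain module-level dictionary keyed by whatever identifies the precomputed file. A minimal sketch, in which _demo_cache, load_cached, the fake loader, and the file name are all hypothetical (the key used by _igso3_cache is not described here):

```python
_demo_cache = {}


def load_cached(path, loader):
    """Return loader(path), calling the loader at most once per distinct path in this process."""
    if path not in _demo_cache:
        _demo_cache[path] = loader(path)
    return _demo_cache[path]


calls = []


def fake_loader(path):
    """Stand-in for an expensive disk read; records each call."""
    calls.append(path)
    return {"path": path}


first = load_cached("igso3_demo.pkl", fake_loader)
second = load_cached("igso3_demo.pkl", fake_loader)
```

Because the cached object is shared between callers, it should be treated as read-only.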
DDIM sampling (inference/utils.py):
- Add get_mu_xt_x0_ddim() implementing the deterministic DDIM update rule.
- Wire ddim=True flag through Denoise.__init__() -> get_next_pose() -> get_next_ca().
Setting ddim=True produces deterministic, lower-variance trajectories and
enables fewer-step inference at equivalent quality.
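For reference, the deterministic DDIM update (eta = 0) on coordinates can be sketched as below. ddim_step is a hypothetical name and signature, not that of get_mu_xt_x0_ddim(), and the alpha_bar values are illustrative. Given a predicted x0, the step recovers the implied noise and re-noises to the previous level without drawing new randomness:

```python
import numpy as np


def ddim_step(x_t, x0_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update from the noise level at t to the previous level."""
    # Noise implied by x_t and the model's x0 prediction.
    eps_pred = (x_t - np.sqrt(alpha_bar_t) * x0_pred) / np.sqrt(1.0 - alpha_bar_t)
    # Re-noise the prediction to the previous level using the same noise direction.
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps_pred


rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 3))   # e.g. four CA coordinates
eps = rng.normal(size=(4, 3))
ab_t, ab_prev = 0.5, 0.8

# Forward-noise x0, then step back with a perfect x0 prediction.
x_t = np.sqrt(ab_t) * x0 + np.sqrt(1.0 - ab_t) * eps
x_prev = ddim_step(x_t, x0, ab_t, ab_prev)
expected = np.sqrt(ab_prev) * x0 + np.sqrt(1.0 - ab_prev) * eps
```

With a perfect prediction the step lands exactly on the forward-process sample at the previous noise level, and stepping to alpha_bar_prev = 1 returns x0_pred itself.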
Numerical stability (kinematics.py):
- Clamp input to acos in get_ang() to [-1, 1] to prevent NaN when float
rounding pushes the cosine slightly outside that range.
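The failure mode and the fix can be reproduced in NumPy, whose arccos returns NaN outside [-1, 1] just as torch.acos does. angle_at_b is a hypothetical helper that mirrors the idea of get_ang(), not its actual code:

```python
import numpy as np


def angle_at_b(a, b, c):
    """Angle a-b-c in radians, with the cosine clamped to [-1, 1] before arccos."""
    v = a - b
    w = c - b
    v = v / np.linalg.norm(v, axis=-1, keepdims=True)
    w = w / np.linalg.norm(w, axis=-1, keepdims=True)
    cos = np.clip(np.sum(v * w, axis=-1), -1.0, 1.0)
    return np.arccos(cos)


# Cosines that rounding has pushed just outside the valid domain.
cos_vals = np.array([1.0 + 1e-7, -1.0 - 1e-7, 0.5])
with np.errstate(invalid="ignore"):
    raw = np.arccos(cos_vals)                         # NaN for the first two entries
clamped = np.arccos(np.clip(cos_vals, -1.0, 1.0))     # finite everywhere

# Collinear points, where the true angle is 0 or pi.
origin = np.zeros(3)
same_dir = angle_at_b(np.array([1.0, 0.0, 0.0]), origin, np.array([2.0, 0.0, 0.0]))
opposite = angle_at_b(np.array([1.0, 0.0, 0.0]), origin, np.array([-3.0, 0.0, 0.0]))
```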