
Revisions №79

branch: RFdiffusion:commits 「№79」
Committed by: Clay Moore
GitHub Pull Request link: 「№454」
Merge: 「2d0c003df」「63f0e71f4」  code diff
Scheduled at: 2026-05-19 16:31:57.429723
rfd

Pull Request №454 RosettaCommons/RFdiffusion/main ← mooreneural/RFdiffusion/main
Merge: 2d0c003df46b9db41d119321f15403dec3716cd9 ← 63f0e71f4edd2d79c46f04b603fe6e628680418c

perf/accuracy: Flash Attention, torch-native SO(3), cosine schedule, DDIM, analytical g(t), acos fix

----------------
Merge commit message:

perf/accuracy: Flash Attention, torch-native SO3, cosine schedule, DDIM, analytical g(t)

Attention (Attention_module.py):
- Replace hand-rolled einsum attention with F.scaled_dot_product_attention in Attention, AttentionWithBias, and MSAColAttention. Uses Flash Attention automatically when available on CUDA (20-40% speedup, O(1) memory).
- AttentionWithBias passes the pairwise bias as attn_mask so it is folded into the fused kernel rather than materializing a separate attention matrix.

SO3 diffusion (igso3.py, diffusion.py, inference/utils.py):
- Add hat_batch(), Log_torch(), Exp_torch(): on-device rotation ops using the Rodrigues formula. Eliminates all scipy CPU round-trips during inference.
- Replace scipy_R calls in reverse_sample_vectorized() and diffuse_frames() with the new torch-native equivalents (stay on GPU, no .cpu()/.numpy() transfers).
- Remove redundant scipy rotation normalization in get_next_frames(); rotation matrices from rigid_from_3_points are already orthogonal.

Noise schedule (diffusion.py):
- Add cosine schedule (Nichol & Dhariwal, 2021). Enabled via schedule_type="cosine"; b0/bT are ignored for this mode.
- Analytical g(t) for linear schedule: eliminates a per-step autograd call. Formula: g(t) = sqrt(2 * sigma(t) * (min_b + t*(max_b - min_b))).

IGSO3 cache (diffusion.py):
- Add module-level _igso3_cache dict. Avoids repeated disk deserialization when multiple Diffuser objects are created in the same process (batch inference).

DDIM sampling (inference/utils.py):
- Add get_mu_xt_x0_ddim() implementing the deterministic DDIM update rule.
- Wire ddim=True flag through Denoise.__init__() -> get_next_pose() -> get_next_ca(). Setting ddim=True produces deterministic, lower-variance trajectories and enables fewer-step inference at equivalent quality.

Numerical stability (kinematics.py):
- Clamp input to acos in get_ang() to [-1, 1] to prevent NaN from float rounding at exactly +/-1.
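
The analytical g(t) formula in the commit message can be sanity-checked numerically. The block below is a minimal sketch, not the code from the PR. It assumes that the linear-schedule sigma(t) has the closed form min_sigma + t*min_b + 0.5*t^2*(max_b - min_b), which is not stated in the commit message but is the form whose derivative equals the (min_b + t*(max_b - min_b)) factor, so that g(t) = sqrt(d/dt sigma(t)^2). The parameter values and the function names sigma, g_analytical, and g_numerical are hypothetical and chosen only for illustration.

```python
import math

# Hypothetical schedule parameters, for illustration only.
MIN_SIGMA = 0.02
MIN_B = 1.5
MAX_B = 2.5


def sigma(t, min_sigma=MIN_SIGMA, min_b=MIN_B, max_b=MAX_B):
    """Assumed linear-schedule sigma(t); this closed form is an assumption."""
    return min_sigma + t * min_b + 0.5 * t**2 * (max_b - min_b)


def g_analytical(t, min_sigma=MIN_SIGMA, min_b=MIN_B, max_b=MAX_B):
    """Formula from the commit message: sqrt(2 * sigma(t) * sigma'(t))."""
    d_sigma = min_b + t * (max_b - min_b)  # derivative of the assumed sigma(t)
    return math.sqrt(2.0 * sigma(t, min_sigma, min_b, max_b) * d_sigma)


def g_numerical(t, eps=1e-6):
    """Reference value: sqrt of a central finite difference of sigma(t)^2."""
    d_sigma_sq = (sigma(t + eps) ** 2 - sigma(t - eps) ** 2) / (2.0 * eps)
    return math.sqrt(d_sigma_sq)


if __name__ == "__main__":
    # Compare the closed form against the numerical derivative at a few times.
    for t in (0.1, 0.5, 1.0):
        print(t, g_analytical(t), g_numerical(t))
```

Under these assumptions the two values agree to within finite-difference error, which is the property that lets the closed form stand in for a per-step autograd call.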

...
Test: ubuntu-20.04.clang.python39.rfd
