| """ |
| Iterative Refinement via Langevin Noise-Refine Cycles. |
| |
| Inspired by ProDifEvo (Uehara et al., ICML 2025): repeatedly perturb and |
| refine structures through Q_theta gradient ascent. Each cycle adds noise |
| for diversity, then refines with Langevin dynamics toward higher selectivity. |
| |
| This allows designs to escape local optima and explore better selectivity |
| regions that single-shot generation cannot reach. |
| |
| Pipeline: |
| 1. Start from existing PXDesign outputs (seed structures) |
| 2. Align binder to reference receptor frames |
| 3. Run Langevin refinement with Q_theta gradient |
| 4. Score the refined output |
| 5. Repeat for K iterations, keeping best designs |
| |
| Usage: |
| python code/scripts/pxdesign_guidance/iterative_refinement.py \ |
| --input_dir results/pxdesign_guided/converted_pdbs \ |
| --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \ |
| --ref_holo data/pdbs/cam_holo/3CLN.pdb \ |
| --ref_apo data/pdbs/cam_apo/1CFD.pdb \ |
| --n_iterations 3 --n_designs 10 \ |
| --gpu 6 |
| """ |
| import os |
| import sys |
| import json |
| import logging |
| import numpy as np |
| import torch |
| from glob import glob |
|
|
# Repository layout (see the usage example above): <root>/code/scripts/pxdesign_guidance/<script>.
# The code directory sits two levels above this script and the repository root one level above that.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR, _ALLO_ROOT = (
    os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..')),
    os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..', '..')),
)

# Make project packages (e.g. `utils`, `scripts`) importable when run as a plain script.
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)

# Timestamped INFO-level logging for the whole script.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def score_designs(pdb_paths, guidance):
    """Score a list of structure files with Q_theta.

    Args:
        pdb_paths: paths passed one by one to ``guidance.score_design``.
        guidance: object whose ``score_design(path)`` returns a result dict, or
            None when the design cannot be scored.

    Returns:
        The non-None result dicts, in input order, each augmented with
        ``pdb_path`` (the input path) and ``design_id`` (the file's basename
        with a trailing ``.pdb`` / ``.cif`` extension removed).
    """
    results = []
    for pdb_path in pdb_paths:
        result = guidance.score_design(pdb_path)
        if result is None:
            continue
        # Strip only a trailing extension; ".pdb"/".cif" elsewhere in the name is kept.
        design_id = os.path.basename(pdb_path)
        for suffix in ('.pdb', '.cif'):
            if design_id.endswith(suffix):
                design_id = design_id[:-len(suffix)]
                break
        result['pdb_path'] = pdb_path
        result['design_id'] = design_id
        results.append(result)
    return results
|
|
|
|
# Backbone atoms in the order of the coordinate arrays' atom axis (CA is index 1, as
# used below), paired with the element symbol written to the PDB element column.
_BACKBONE_ATOMS = (('N', 'N'), ('CA', 'C'), ('C', 'C'), ('O', 'O'))


def _format_atom_line(serial, atom_name, res_name, chain_id, res_seq, xyz, element):
    """Return one fixed-column PDB ``ATOM`` record (PDB format v3.3 column layout).

    One-letter-element atom names start in column 14 (" CA "), occupancy is
    written as 1.00 and the B-factor as 0.00.
    """
    x, y, z = (float(v) for v in xyz)
    return (f"ATOM  {serial:5d}  {atom_name:<3s} {res_name:>3s} {chain_id}{res_seq:4d}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {element:>2s}\n")


def _residue_name(residue, default):
    """Return ``residue.get_resname()`` (if the object has it and it is 1-3 chars), else *default*."""
    get_resname = getattr(residue, 'get_resname', None)
    name = get_resname().strip() if callable(get_resname) else ''
    return name if 0 < len(name) <= 3 else default


def _write_backbone_pdb(out_path, chains):
    """Write backbone-only chains to *out_path* as a PDB file.

    Args:
        out_path: destination file (overwritten).
        chains: iterable of ``(chain_id, coords, residue_mask, residue_names)``.
            ``coords`` is indexed ``[residue, atom, xyz]`` with atoms in N, CA, C, O
            order; residues whose mask entry is falsy are skipped. Residue numbers
            are 1-based positions in the chain (skipped residues keep their slot).
    """
    lines = []
    serial = 1
    for chain_id, coords, mask, res_names in chains:
        for i, residue_xyz in enumerate(coords):
            if mask is not None and not mask[i]:
                continue
            for (atom_name, element), xyz in zip(_BACKBONE_ATOMS, residue_xyz):
                lines.append(_format_atom_line(serial, atom_name, res_names[i],
                                               chain_id, i + 1, xyz, element))
                serial += 1
        lines.append("TER\n")
    lines.append("END\n")
    with open(out_path, 'w') as f:
        f.writelines(lines)


def _project_bond_lengths(coords, residue_mask=None, target_dist=3.8, n_iters=5):
    """Nudge consecutive CA-CA distances toward *target_dist* (SHAKE-like, in place).

    Each iteration computes every bond's correction from the same snapshot of CA
    positions (Jacobi style) and applies half of it to each endpoint, translating
    all backbone atoms of a residue rigidly. Near-zero-length bonds and bonds that
    touch a residue with a falsy *residue_mask* entry are left alone.

    Args:
        coords: tensor indexed ``[residue, atom, xyz]`` with CA at atom index 1.
        residue_mask: optional per-residue bool tensor (assumed shape ``(L,)``).
        target_dist: desired CA-CA distance (3.8 is the usual value in Angstrom).
        n_iters: number of projection sweeps.

    Returns:
        *coords*, modified in place.
    """
    if coords.shape[0] < 2:
        return coords
    with torch.no_grad():
        for _ in range(n_iters):
            ca = coords[:, 1, :]
            delta = ca[1:] - ca[:-1]                          # (L-1, 3), new tensor
            dist = delta.norm(dim=-1, keepdim=True)           # (L-1, 1)
            valid = dist >= 1e-6
            if residue_mask is not None:
                valid = valid & (residue_mask[1:] & residue_mask[:-1]).unsqueeze(-1)
            scale = torch.where(valid,
                                0.5 * (dist - target_dist) / dist.clamp(min=1e-6),
                                torch.zeros_like(dist))
            correction = (scale * delta).unsqueeze(1)         # broadcast over atoms
            coords[:-1] += correction
            coords[1:] -= correction
    return coords


def _langevin_optimize(coords_t, mask_t, aa_t, dq, holo_to_apo, n_steps, step_size, tag):
    """Maximise ``Q_holo - Q_apo`` by normalised gradient ascent with annealed noise.

    Args:
        coords_t: starting binder backbone in the holo reference frame, reshaped
            internally as ``(-1, 4, 3)`` for the apo copy.
        mask_t, aa_t: residue mask and amino-acid indices passed to ``dq.score``.
        dq: differentiable scorer exposing ``score(coords, mask, binder_aa_idx=,
            receptor_label=)``.
        holo_to_apo: differentiable map from flat ``(N, 3)`` holo-frame coordinates
            to the apo frame.
        n_steps, step_size: number and length of normalised gradient steps.
        tag: label used in progress log lines.

    Returns:
        ``(best_coords, best_margin)``: the evaluated coordinates with the highest
        finite margin (the starting coordinates and ``-inf`` if none was finite).
    """
    best_margin = -float('inf')
    best_coords = coords_t.detach().clone()

    for step in range(n_steps):
        coords_t = coords_t.detach().requires_grad_(True)
        with torch.enable_grad():
            q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                              receptor_label='holo')
            coords_apo = holo_to_apo(coords_t.reshape(-1, 3)).reshape(-1, 4, 3)
            q_apo = dq.score(coords_apo, mask_t, binder_aa_idx=aa_t,
                             receptor_label='apo')
            margin = q_holo - q_apo
            # autograd.grad (unlike backward) does not accumulate into model parameters.
            (grad,) = torch.autograd.grad(margin, coords_t, allow_unused=True)

        margin_val = margin.item()
        if np.isfinite(margin_val) and margin_val > best_margin:
            best_margin = margin_val
            best_coords = coords_t.detach().clone()

        # Without a usable gradient the coordinates would not move, so every later
        # step would just repeat this one; stop instead.
        if grad is None or not torch.isfinite(grad).all():
            logger.warning(f"  [{tag}] Step {step}: invalid gradient, stopping early")
            break

        grad_norm = grad.norm().clamp(min=1e-8)
        if step % 10 == 0:
            logger.info(f"  [{tag}] Step {step}: "
                        f"Q+={q_holo.item():.3f} Q-={q_apo.item():.3f} "
                        f"S={margin_val:.3f} |g|={grad_norm.item():.4f}")

        with torch.no_grad():
            coords_t = coords_t + step_size * grad / grad_norm
            # Exploration noise, linearly annealed to zero over the run.
            noise_scale = step_size * 0.05 * (1 - step / n_steps)
            coords_t = coords_t + noise_scale * torch.randn_like(coords_t)
            coords_t = _project_bond_lengths(coords_t, mask_t)

    return best_coords, best_margin


def run_langevin_cycle(pdb_paths, guidance, n_steps=50, step_size=0.005,
                       iteration=0, outdir='results/iterative_refinement'):
    """Run one Langevin refinement cycle on each design's binder backbone.

    For every input structure, chain 'A' is the receptor and the first other chain
    (sorted by id) is the binder. The binder backbone is moved into the holo
    reference frame by aligning the receptor CAs onto ``guidance.ref_holo_ca``. It is
    then perturbed with Gaussian noise and optimised for ``Q_holo - Q_apo`` using
    ``guidance.dq``; the apo score uses the receptor's alignment onto
    ``guidance.ref_apo_ca``. The best coordinates are written to
    ``{outdir}/{stem}_iter{iteration}.pdb``, with the receptor backbone (chain 'A')
    in the same frame and the binder (chain 'B') keeping its residue names, so the
    file can be refined again. That file is then re-scored with
    ``guidance.score_design``.

    Args:
        pdb_paths: structure files to refine.
        guidance: object exposing ``device``, ``dq``, ``ref_holo_ca``, ``ref_apo_ca``
            (tensors) and ``score_design(path)``.
        n_steps: gradient-ascent steps per design.
        step_size: length of each normalised gradient step, in the structures'
            coordinate units.
        iteration: cycle index, used in output filenames and results.
        outdir: directory for refined PDB files (created if missing).

    Returns:
        List of ``score_design`` result dicts, one per successfully refined design,
        each augmented with ``pdb_path`` (refined file), ``source_pdb`` (input file),
        ``iteration`` and ``best_margin_during_opt``. Designs that fail are logged
        and skipped.
    """
    from utils.pdb_utils import (load_structure, get_residues, get_backbone_coords,
                                 get_aa_indices, align_structures)

    refined_results = []
    os.makedirs(outdir, exist_ok=True)
    if not pdb_paths:
        return refined_results

    device = guidance.device
    dq = guidance.dq
    ref_holo_ca = guidance.ref_holo_ca.cpu().numpy()
    ref_apo_ca = guidance.ref_apo_ca.cpu().numpy()

    def to_device(array, dtype=torch.float32):
        return torch.tensor(np.asarray(array), dtype=dtype, device=device)

    for pdb_path in pdb_paths:
        name = os.path.basename(pdb_path)
        try:
            model = load_structure(pdb_path)
            chains = {c.id: c for c in model.get_chains()}
            if 'A' not in chains:
                logger.warning(f"Skipping {name}: no receptor chain 'A'")
                continue
            binder_chain = next((cid for cid in sorted(chains) if cid != 'A'), None)
            if binder_chain is None:
                continue

            # Prefer standard residues; fall back to all residues if none are standard.
            rec_res = (get_residues(chains['A'])
                       or get_residues(chains['A'], only_standard=False))
            binder_res = (get_residues(chains[binder_chain])
                          or get_residues(chains[binder_chain], only_standard=False))
            if len(binder_res) < 5:
                continue

            binder_coords, binder_mask = get_backbone_coords(binder_res)
            rec_coords, rec_mask = get_backbone_coords(rec_res)
            try:
                aa_idx = get_aa_indices(binder_res)
            except Exception:
                # Unknown residue types: fall back to index 0 for every position.
                aa_idx = np.zeros(len(binder_res), dtype=np.int64)

            # Receptor/reference CAs are paired by index over the common prefix.
            rec_ca = rec_coords[:, 1, :]
            n_h = min(len(rec_ca), len(ref_holo_ca))
            n_a = min(len(rec_ca), len(ref_apo_ca))
            if n_h < 5 or n_a < 5:
                continue

            _, R_h = align_structures(rec_ca[:n_h], ref_holo_ca[:n_h])
            center_h = rec_ca[:n_h].mean(0)
            ref_center_h = ref_holo_ca[:n_h].mean(0)
            _, R_a = align_structures(rec_ca[:n_a], ref_apo_ca[:n_a])
            center_a = rec_ca[:n_a].mean(0)
            ref_center_a = ref_apo_ca[:n_a].mean(0)

            def to_holo_frame(xyz):
                flat = np.asarray(xyz).reshape(-1, 3)
                return (flat - center_h) @ R_h.T + ref_center_h

            aligned_holo = to_holo_frame(binder_coords).reshape(-1, 4, 3)
            rec_holo = to_holo_frame(rec_coords).reshape(np.shape(rec_coords))

            R_h_t, R_a_t = to_device(R_h), to_device(R_a)
            center_h_t, ref_center_h_t = to_device(center_h), to_device(ref_center_h)
            center_a_t, ref_center_a_t = to_device(center_a), to_device(ref_center_a)

            def holo_to_apo(flat_t):
                # Undo the holo alignment (R_h is orthogonal), then apply the apo one.
                original = (flat_t - ref_center_h_t) @ R_h_t + center_h_t
                return (original - center_a_t) @ R_a_t.T + ref_center_a_t

            coords_t = to_device(aligned_holo)
            coords_t = coords_t + torch.randn_like(coords_t) * 0.05  # initial perturbation
            mask_t = to_device(binder_mask, torch.bool)
            aa_t = to_device(aa_idx, torch.long)

            best_coords, best_margin = _langevin_optimize(
                coords_t, mask_t, aa_t, dq, holo_to_apo, n_steps, step_size, name)

            stem = name
            for suffix in ('.pdb', '.cif'):
                if stem.endswith(suffix):
                    stem = stem[:-len(suffix)]
                    break
            out_path = os.path.join(outdir, f'{stem}_iter{iteration}.pdb')
            _write_backbone_pdb(out_path, [
                ('A', rec_holo, rec_mask, [_residue_name(r, 'UNK') for r in rec_res]),
                ('B', best_coords.cpu().numpy(), binder_mask,
                 [_residue_name(r, 'ALA') for r in binder_res]),
            ])

            result = guidance.score_design(out_path)
            if result is None:
                logger.warning(f"  -> Could not score refined design {out_path}")
                continue
            result['pdb_path'] = out_path
            result['source_pdb'] = pdb_path
            result['iteration'] = iteration
            result['best_margin_during_opt'] = best_margin
            refined_results.append(result)
            logger.info(f"  -> Refined: S={result['margin']:.3f} "
                        f"(best during opt: {best_margin:.3f})")

        except Exception as e:
            # Best effort: one bad design must not abort the whole cycle.
            logger.exception(f"Failed to refine {pdb_path}: {e}")

    return refined_results
|
|
|
|
def _margin_stats(results):
    """Return summary statistics of ``r['margin']`` over *results*, or None if empty.

    Keys: ``n``, ``margin_mean``, ``margin_std`` (population std, ddof=0),
    ``margin_max``, ``frac_positive`` (fraction of margins > 0); all plain floats/ints.
    """
    if not results:
        return None
    margins = np.asarray([r['margin'] for r in results], dtype=float)
    return {
        'n': len(results),
        'margin_mean': float(margins.mean()),
        'margin_std': float(margins.std()),
        'margin_max': float(margins.max()),
        'frac_positive': float((margins > 0).mean()),
    }


def _format_stats(stats):
    """Format the margin mean and std as used in log lines."""
    return f"S={stats['margin_mean']:.3f}\u00b1{stats['margin_std']:.3f}"


def main():
    """Command-line entry point: score seed designs, refine them, and summarise.

    Each design keeps a current seed. Every iteration refines each seed once and
    replaces it only when the refined structure scores a higher margin, so the
    final seeds are the best design found per input. Per-iteration statistics
    (plus ``final_best`` with the chosen paths) go to
    ``{outdir}/iterative_refinement_summary.json``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir',
                        default='results/pxdesign_guided/converted_pdbs')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_iterations', type=int, default=4,
                        help='Number of refine cycles')
    parser.add_argument('--n_designs', type=int, default=20,
                        help='Number of designs to refine')
    parser.add_argument('--n_steps', type=int, default=50,
                        help='Langevin steps per iteration')
    parser.add_argument('--step_size', type=float, default=0.005)
    parser.add_argument('--gpu', type=int, default=6)
    parser.add_argument('--outdir', default='results/iterative_refinement')
    args = parser.parse_args()

    # Relative paths (including the defaults above) are resolved from the repository root.
    os.chdir(_ALLO_ROOT)

    from scripts.pxdesign_guidance.qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = args.outdir
    os.makedirs(outdir, exist_ok=True)

    guidance = QThetaPXDesignGuidance(
        checkpoint=args.qtheta_checkpoint,
        ref_holo=args.ref_holo,
        ref_apo=args.ref_apo,
        ref_chain=args.ref_chain,
        device=f'cuda:{args.gpu}',
    )
    guidance._lazy_init()

    input_pdbs = sorted(glob(os.path.join(args.input_dir, '*.pdb')))[:args.n_designs]
    logger.info(f"Selected {len(input_pdbs)} designs for iterative refinement")
    if not input_pdbs:
        logger.warning(f"No *.pdb files found in {args.input_dir}; nothing to refine")
        return

    logger.info("Scoring initial designs...")
    initial_results = score_designs(input_pdbs, guidance)
    initial_stats = _margin_stats(initial_results)
    if initial_stats is None:
        logger.warning("Initial: no design could be scored")
    else:
        logger.info(f"Initial: {_format_stats(initial_stats)}")

    all_iteration_results = {'initial': initial_results}

    # One (seed path, seed margin) pair per design; unscored seeds start at -inf so
    # any successfully scored refinement replaces them.
    initial_margin = {r['pdb_path']: r['margin'] for r in initial_results}
    seeds = [(path, initial_margin.get(path, -float('inf'))) for path in input_pdbs]

    for iteration in range(args.n_iterations):
        logger.info(f"\n{'='*50}")
        logger.info(f"Iteration {iteration + 1}/{args.n_iterations}")
        logger.info(f"{'='*50}")

        iter_results = []
        next_seeds = []
        # Refine designs one at a time so each refined result maps to its seed.
        for seed_path, seed_margin in seeds:
            refined = run_langevin_cycle(
                [seed_path], guidance,
                n_steps=args.n_steps,
                step_size=args.step_size,
                iteration=iteration,
                outdir=outdir,
            )
            iter_results.extend(refined)
            improved = next((r for r in refined if r['margin'] > seed_margin), None)
            if improved is not None:
                next_seeds.append((improved['pdb_path'], improved['margin']))
            else:
                next_seeds.append((seed_path, seed_margin))
        seeds = next_seeds

        iter_stats = _margin_stats(iter_results)
        if iter_stats is None:
            logger.warning(f"Iteration {iteration + 1}: no design was refined successfully")
        else:
            logger.info(f"Iteration {iteration + 1}: {_format_stats(iter_stats)}")
            all_iteration_results[f'iteration_{iteration}'] = iter_results

    best_designs = sorted(
        ({'pdb_path': path, 'margin': float(margin)}
         for path, margin in seeds if np.isfinite(margin)),
        key=lambda d: d['margin'], reverse=True)
    if best_designs:
        all_iteration_results['final_best'] = best_designs

    logger.info(f"\n{'='*60}")
    logger.info("Iterative Refinement Summary")
    logger.info(f"{'='*60}")
    summary = {}
    for key, results in all_iteration_results.items():
        stats = _margin_stats(results)
        if stats is None:
            continue
        summary[key] = stats
        logger.info(f"{key:15s}: {_format_stats(stats)}, "
                    f"N={stats['n']}, S>0={100 * stats['frac_positive']:.0f}%")
    if 'final_best' in summary:
        summary['final_best']['designs'] = best_designs

    out_path = os.path.join(outdir, 'iterative_refinement_summary.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2)
    logger.info(f"\nSaved to {out_path}")


if __name__ == '__main__':
    main()
|
|