feat: add `inference.empty_cache_per_design` flag to reduce CUDA allocator fragmentation (#451)
Conversation
This failure is pre-existing and unrelated to this PR. From the CI metadata: the ~6 minutes of runtime was entirely environment setup (virtualenv creation plus package installation). The actual test crashed instantly at module import, before any test logic ran: `import dgl` fails because `torchdata.datapipes` was removed in torchdata >= 0.7.0, and the DGL version in the benchmark environment pulls in graphbolt, which requires it. This is a dependency-pin issue. The CI report itself attributes the breakage to commit №20 by Junior Martins, not to this PR. The second traceback (`shutil.move`: `FileNotFoundError: tests/outputs`) is a downstream symptom: the import crash meant no test outputs were ever produced for the move step. Recommended fix for CI: pin `dgl < 2.0` or `torchdata < 0.7.0` in the ubuntu-20.04.clang.python39.rfd test environment to restore the `torchdata.datapipes` API. This PR's changes do not touch the dependency stack.
@lyskov Could you please review this?
The CI failures are pre-existing and unrelated to this PR.

Root cause: `RuntimeError: Numpy is not available` in `rfdiffusion/igso3.py:93` during `Diffuser.__init__()`. NumPy's C extension (`_ARRAY_API`) fails to load in the CI environment (Python 3.9 venv) — this is visible as a `UserWarning: Failed to initialize NumPy: _ARRAY_API not found` at the very start of the log, before any test logic runs. The crash propagates up through `IGSO3.__init__()` → `Diffuser.__init__()` → `SelfConditioning.initialize()` → `sampler_selector()`, which is why line 54 of `run_inference.py` appears as the reported crash site — it is simply the outermost frame.

This PR's changes are post-loop cleanup at lines 191–201 of `run_inference.py`, which are only reached after a design completes successfully; they are unreachable from this failure. Running the same test suite against upstream main without these changes produces the identical error.
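The "outermost frame" point can be illustrated with a minimal, self-contained sketch (the function names here are placeholders mirroring the call chain above, not the real code):

```python
import traceback

def diffuser_init():
    # Stand-in for IGSO3.__init__ raising when NumPy's C extension fails.
    raise RuntimeError("Numpy is not available")

def sampler_selector():
    diffuser_init()

def run_inference():
    # Outermost frame, analogous to run_inference.py line 54.
    sampler_selector()

try:
    run_inference()
except RuntimeError as exc:
    frames = [f.name for f in traceback.extract_tb(exc.__traceback__)]

# The traceback lists the outermost frame first and the actual raise site
# last, so a report quoting an early frame points at run_inference even
# though the error originated deep inside diffuser_init.
```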
Problem
When running RFdiffusion with variable-length contigs (e.g. `contigmap.contigs=[A1-469/0 1-50]`) over hundreds or thousands of designs, per-worker VRAM grows steadily from ~7 GB to 10–13 GB per process. This limits how many workers can run in parallel on a single GPU before exhausting VRAM.

Root cause: PyTorch's CUDA caching allocator accumulates fragmented memory blocks across designs. With variable-length contigs each design allocates differently-sized tensors; freed blocks are cached but cannot be reused for different-sized allocations, causing steady VRAM growth.
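The growth pattern can be sketched with a toy model of a size-bucketed caching allocator (pure Python, no GPU needed; all names here are illustrative and much simpler than PyTorch's actual allocator):

```python
import random

class ToyCachingAllocator:
    """Toy stand-in for a caching allocator: freed blocks are kept in a
    per-size free list and can only satisfy a request of the exact same
    size, so mismatched sizes force the pool to keep growing."""

    def __init__(self):
        self.free_blocks = {}  # block size -> number of cached free blocks
        self.reserved = 0      # total memory ever claimed from the device

    def alloc(self, size):
        if self.free_blocks.get(size, 0):
            self.free_blocks[size] -= 1  # reuse a cached block of this size
        else:
            self.reserved += size        # no matching block cached: grow

    def release(self, size):
        # Freed memory is cached, not returned to the device.
        self.free_blocks[size] = self.free_blocks.get(size, 0) + 1

    def empty_cache(self):
        # Analogue of torch.cuda.empty_cache(): hand cached blocks back.
        for size, count in self.free_blocks.items():
            self.reserved -= size * count
        self.free_blocks.clear()

rng = random.Random(0)
cached_only = ToyCachingAllocator()
flushed = ToyCachingAllocator()
for _ in range(1000):              # one iteration per "design"
    size = rng.randint(1, 10_000)  # variable-length contig -> varied size
    for allocator in (cached_only, flushed):
        allocator.alloc(size)
        allocator.release(size)
    flushed.empty_cache()          # empty_cache_per_design=True analogue

# cached_only.reserved has grown steadily; flushed.reserved is back to 0
```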
Fix
Add an optional `inference.empty_cache_per_design` flag (default `False`, opt-in) that calls `torch.cuda.empty_cache()` at the end of each design iteration. This releases all unused cached CUDA memory blocks back to the CUDA memory manager, keeping each worker near its initial VRAM footprint for the full run.

Changes
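A sketch of the two edits (the YAML key name is from this PR; the config object name `conf` and the exact surrounding code are assumptions about `run_inference.py`, not the actual diff):

```yaml
# config/inference/base.yaml — new key, defaulting to the old behavior
inference:
  empty_cache_per_design: False
```

```python
# scripts/run_inference.py — end of the per-design loop, after the
# trajectory/PDB writes (sketch only; surrounding code omitted)
if conf.inference.empty_cache_per_design:
    torch.cuda.empty_cache()  # return unused cached blocks to the device
```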
- `config/inference/base.yaml` — declare the new flag with its `False` default
- `scripts/run_inference.py` — after the trajectory/PDB write block, before the final `log.info`

Measured impact
Tested on NVIDIA RTX 5090 32 GB running a long PPI campaign with variable-length contigs:
With `empty_cache_per_design=True`, per-worker VRAM stayed near its initial footprint for the entire campaign instead of climbing to 10–13 GB. This allowed raising the number of parallel workers from 3 to 5 on a 32 GB GPU.
Why opt-in
`torch.cuda.empty_cache()` adds a small per-design overhead (~1–2 ms) and is only beneficial for long runs with variable-length contigs. For short runs or fixed-length designs there is no fragmentation issue, so the default remains `False` to preserve existing behavior.

Testing
All 20 applicable tests in `tests/test_diffusion.py` pass with this change. The one remaining test (`design_ppi_scaffolded`) is skipped because it fails due to a missing `ppi_scaffolds/` directory in the test fixture — a pre-existing issue unrelated to this PR.

Notes
The `empty_cache()` call sits after the final PDB write (`writepdb`) and the optional trajectory block — every consumer of `denoised_xyz_stack`/`px0_xyz_stack` has already finished before the cache is cleared.