diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 066cec41b7..1db4ce9558 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -232,6 +232,7 @@ def __init__(
         reset_ops_id: bool = True,
         track_meta: bool = False,
         weights_only: bool = True,
+        in_memory: bool = False,
     ) -> None:
         """
         Args:
@@ -273,6 +274,10 @@ def __init__(
                 other safe objects. Setting this to `False` is required for loading `MetaTensor` objects saved
                 with `track_meta=True`, however this creates the possibility of remote code execution through
                 `torch.load` so be aware of the security implications of doing so.
+            in_memory: if `True`, keep the pre-processed data in an in-memory dictionary after first access.
+                This combines the benefits of persistent storage (data survives restarts) with faster RAM access.
+                When data is accessed, it is first loaded from disk cache and then stored in memory.
+                Default to `False`.
 
         Raises:
             ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
@@ -299,6 +304,13 @@ def __init__(
             )
         self.track_meta = track_meta
         self.weights_only = weights_only
+        self.in_memory = in_memory
+        self._memory_cache: dict[str, Any] = {}
+
+    @property
+    def memory_cache_size(self) -> int:
+        """Return the number of items currently stored in the in-memory cache."""
+        return len(self._memory_cache)
 
     def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
         """Get hashable transforms, and then hash them. Hashable transforms
@@ -326,6 +338,7 @@ def set_data(self, data: Sequence):
 
         """
         self.data = data
+        self._memory_cache = {}
         if self.cache_dir is not None and self.cache_dir.exists():
             shutil.rmtree(self.cache_dir, ignore_errors=True)
             self.cache_dir.mkdir(parents=True, exist_ok=True)
@@ -389,14 +402,24 @@ def _cachecheck(self, item_transformed):
 
         """
         hashfile = None
+        # compute cache key once for both disk and memory caching
+        data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
+        data_item_md5 += self.transform_hash
+        cache_key = f"{data_item_md5}.pt"
+
         if self.cache_dir is not None:
-            data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
-            data_item_md5 += self.transform_hash
-            hashfile = self.cache_dir / f"{data_item_md5}.pt"
+            hashfile = self.cache_dir / cache_key
+
+        # check in-memory cache first
+        if self.in_memory and cache_key in self._memory_cache:
+            return self._memory_cache[cache_key]
 
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=self.weights_only)
+                _item_transformed = torch.load(hashfile, weights_only=self.weights_only)
+                if self.in_memory:
+                    self._memory_cache[cache_key] = _item_transformed
+                return _item_transformed
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -409,6 +432,8 @@ def _cachecheck(self, item_transformed):
 
         _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed
         if hashfile is None:
+            if self.in_memory:
+                self._memory_cache[cache_key] = _item_transformed
             return _item_transformed
         try:
             # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
@@ -431,6 +456,8 @@ def _cachecheck(self, item_transformed):
             pass
         except PermissionError:  # project-monai/monai issue #3613
             pass
+        if self.in_memory:
+            self._memory_cache[cache_key] = _item_transformed
         return _item_transformed
 
     def _transform(self, index: int):
diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py
index ca62cdb184..9912904502 100644
--- a/tests/data/test_persistentdataset.py
+++ b/tests/data/test_persistentdataset.py
@@ -15,6 +15,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 import nibabel as nib
 import numpy as np
@@ -200,6 +201,111 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er
             im = test_dataset[0]["image"]
             self.assertIsInstance(im, expected_type)
 
+    def test_in_memory_cache(self):
+        """Test in_memory caching feature that combines persistent storage with RAM caching."""
+        items = [[list(range(i))] for i in range(5)]
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            # First, create the persistent cache
+            ds1 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=False)
+            # Access all items to populate disk cache
+            _ = list(ds1)
+
+            # Now create a new dataset with in_memory=True
+            ds2 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)
+
+            # Memory cache should be empty initially
+            self.assertEqual(ds2.memory_cache_size, 0)
+
+            # Access items - they should be loaded from disk and cached in memory
+            _ = ds2[0]
+            self.assertEqual(ds2.memory_cache_size, 1)
+
+            _ = ds2[1]
+            self.assertEqual(ds2.memory_cache_size, 2)
+
+            # Access all items
+            _ = list(ds2)
+            self.assertEqual(ds2.memory_cache_size, 5)
+
+            # Accessing same item again should use memory cache (same result)
+            result1 = ds2[0]
+            result2 = ds2[0]
+            self.assertEqual(result1, result2)
+
+            # Test set_data clears in-memory cache
+            ds2.set_data(items[:3])
+            self.assertEqual(ds2.memory_cache_size, 0)
+
+    def test_in_memory_without_cache_dir(self):
+        """Test in_memory caching works even without a cache_dir (pure RAM cache)."""
+        items = [[list(range(i))] for i in range(3)]
+
+        ds = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=None, in_memory=True)
+
+        # Memory cache should be empty initially
+        self.assertEqual(ds.memory_cache_size, 0)
+
+        # Access items - they should be cached in memory
+        _ = ds[0]
+        self.assertEqual(ds.memory_cache_size, 1)
+
+        _ = list(ds)
+        self.assertEqual(ds.memory_cache_size, 3)
+
+    def test_automatic_hybrid_caching(self):
+        """
+        Test that in_memory=True provides automatic hybrid caching:
+        - ALL samples automatically persist to disk
+        - ALL samples automatically cache to RAM after first access
+        - No manual specification of which samples go where (unlike torchdatasets)
+        - Simulates restart scenario: disk cache survives, RAM cache rebuilds automatically
+        """
+        items = [[list(range(i))] for i in range(5)]
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            # === First "session": populate both disk and RAM cache ===
+            ds1 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)
+
+            # Access all items - should automatically cache to BOTH disk AND RAM
+            for i in range(len(items)):
+                _ = ds1[i]
+
+            # Verify: ALL samples are in RAM (automatic, no manual specification)
+            self.assertEqual(ds1.memory_cache_size, 5)
+
+            # Verify: ALL samples are on disk (count .pt files)
+            cache_files = list(Path(tempdir).glob("*.pt"))
+            self.assertEqual(len(cache_files), 5)
+
+            # === Simulate "restart": new dataset instance, same cache_dir ===
+            # This is the key benefit over CacheDataset - disk cache survives restart
+            ds2 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)
+
+            # RAM cache starts empty (simulating fresh process)
+            self.assertEqual(ds2.memory_cache_size, 0)
+
+            # Access all items - should load from disk and automatically cache to RAM
+            results = [ds2[i] for i in range(len(items))]
+
+            # Verify: ALL samples now in RAM again (automatic rebuild from disk)
+            self.assertEqual(ds2.memory_cache_size, 5)
+
+            # Verify: Results are correct
+            for i, result in enumerate(results):
+                self.assertEqual(result, [list(range(i))])
+
+            # === Verify RAM cache provides fast repeated access ===
+            # Accessing same items again should hit RAM cache (same objects)
+            for i in range(len(items)):
+                result1 = ds2[i]
+                result2 = ds2[i]
+                # Should return equivalent results
+                self.assertEqual(result1, result2)
+
+            # RAM cache size unchanged (no duplicate entries)
+            self.assertEqual(ds2.memory_cache_size, 5)
+
 
 if __name__ == "__main__":
     unittest.main()