diff --git a/changes/3679.feature.md b/changes/3679.feature.md new file mode 100644 index 0000000000..cdf35382e4 --- /dev/null +++ b/changes/3679.feature.md @@ -0,0 +1,3 @@ +Adds a new in-memory storage backend called `ManagedMemoryStore`. Instances of `ManagedMemoryStore` +function similarly to `MemoryStore`, but instances of `ManagedMemoryStore` can be constructed from +a URL like `memory://store`. \ No newline at end of file diff --git a/docs/user-guide/arrays.md b/docs/user-guide/arrays.md index 1675c853fa..36392902b1 100644 --- a/docs/user-guide/arrays.md +++ b/docs/user-guide/arrays.md @@ -14,15 +14,14 @@ np.random.seed(0) ```python exec="true" session="arrays" source="above" result="ansi" import zarr -store = zarr.storage.MemoryStore() -z = zarr.create_array(store=store, shape=(10000, 10000), chunks=(1000, 1000), dtype='int32') +z = zarr.create_array(store="memory://arrays-demo", shape=(10000, 10000), chunks=(1000, 1000), dtype='int32') print(z) ``` The code above creates a 2-dimensional array of 32-bit integers with 10000 rows and 10000 columns, divided into chunks where each chunk has 1000 rows and 1000 -columns (and so there will be 100 chunks in total). The data is written to a -[`zarr.storage.MemoryStore`][] (e.g. an in-memory dict). See +columns (and so there will be 100 chunks in total). The data is written to an +in-memory store (see [`zarr.storage.MemoryStore`][] for more details). See [Persistent arrays](#persistent-arrays) for details on storing arrays in other stores, and see [Data types](data_types.md) for an in-depth look at the data types supported by Zarr. diff --git a/docs/user-guide/attributes.md b/docs/user-guide/attributes.md index 44d2f9fa87..d5961ed38a 100644 --- a/docs/user-guide/attributes.md +++ b/docs/user-guide/attributes.md @@ -3,10 +3,9 @@ Zarr arrays and groups support custom key/value attributes, which can be useful for storing application-specific metadata. For example: -```python exec="true" session="arrays" source="above" result="ansi" +```python exec="true" session="attributes" source="above" result="ansi" import zarr -store = zarr.storage.MemoryStore() -root = zarr.create_group(store=store) +root = zarr.create_group(store="memory://attributes-demo") root.attrs['foo'] = 'bar' z = root.create_array(name='zzz', shape=(10000, 10000), dtype='int32') z.attrs['baz'] = 42 @@ -14,22 +13,22 @@ z.attrs['qux'] = [1, 4, 7, 12] print(sorted(root.attrs)) ``` -```python exec="true" session="arrays" source="above" result="ansi" +```python exec="true" session="attributes" source="above" result="ansi" print('foo' in root.attrs) ``` -```python exec="true" session="arrays" source="above" result="ansi" +```python exec="true" session="attributes" source="above" result="ansi" print(root.attrs['foo']) ``` -```python exec="true" session="arrays" source="above" result="ansi" +```python exec="true" session="attributes" source="above" result="ansi" print(sorted(z.attrs)) ``` -```python exec="true" session="arrays" source="above" result="ansi" +```python exec="true" session="attributes" source="above" result="ansi" print(z.attrs['baz']) ``` -```python exec="true" session="arrays" source="above" result="ansi" +```python exec="true" session="attributes" source="above" result="ansi" print(z.attrs['qux']) ``` diff --git a/docs/user-guide/consolidated_metadata.md b/docs/user-guide/consolidated_metadata.md index d4fc9d6bab..c5cd31e5fc 100644 --- a/docs/user-guide/consolidated_metadata.md +++ b/docs/user-guide/consolidated_metadata.md @@ -27,8 +27,7 @@ import zarr import warnings warnings.filterwarnings("ignore", category=UserWarning) -store = zarr.storage.MemoryStore() -group = zarr.create_group(store=store) +group = zarr.create_group(store="memory://consolidated-metadata-demo") print(group) array = group.create_array(shape=(1,), name='a', dtype='float64') print(array) @@ -45,7 +44,7 @@ print(array) ``` ```python exec="true" session="consolidated_metadata" source="above" result="ansi" -result = zarr.consolidate_metadata(store) +result = zarr.consolidate_metadata("memory://consolidated-metadata-demo") print(result) ``` @@ -56,7 +55,7 @@ that can be used.: from pprint import pprint import io -consolidated = zarr.open_group(store=store) +consolidated = zarr.open_group(store="memory://consolidated-metadata-demo") consolidated_metadata = consolidated.metadata.consolidated_metadata.metadata # Note: pprint can be users without capturing the output regularly @@ -76,7 +75,7 @@ With nested groups, the consolidated metadata is available on the children, recu ```python exec="true" session="consolidated_metadata" source="above" result="ansi" child = group.create_group('child', attributes={'kind': 'child'}) grandchild = child.create_group('child', attributes={'kind': 'grandchild'}) -consolidated = zarr.consolidate_metadata(store) +consolidated = zarr.consolidate_metadata("memory://consolidated-metadata-demo") output = io.StringIO() pprint(consolidated['child'].metadata.consolidated_metadata, stream=output, width=60) diff --git a/docs/user-guide/gpu.md b/docs/user-guide/gpu.md index 3317bdf065..ff86263cf0 100644 --- a/docs/user-guide/gpu.md +++ b/docs/user-guide/gpu.md @@ -20,9 +20,8 @@ buffers used internally by Zarr via `enable_gpu()`. import zarr import cupy as cp zarr.config.enable_gpu() -store = zarr.storage.MemoryStore() z = zarr.create_array( - store=store, shape=(100, 100), chunks=(10, 10), dtype="float32", + store="memory://gpu-demo", shape=(100, 100), chunks=(10, 10), dtype="float32", ) type(z[:10, :10]) # cupy.ndarray diff --git a/docs/user-guide/groups.md b/docs/user-guide/groups.md index 57201216b6..4eff4a1680 100644 --- a/docs/user-guide/groups.md +++ b/docs/user-guide/groups.md @@ -8,8 +8,7 @@ To create a group, use the [`zarr.group`][] function: ```python exec="true" session="groups" source="above" result="ansi" import zarr -store = zarr.storage.MemoryStore() -root = zarr.create_group(store=store) +root = zarr.create_group(store="memory://groups-demo") print(root) ``` @@ -105,8 +104,7 @@ Diagnostic information about arrays and groups is available via the `info` property. E.g.: ```python exec="true" session="groups" source="above" result="ansi" -store = zarr.storage.MemoryStore() -root = zarr.group(store=store) +root = zarr.group(store="memory://diagnostics-demo") foo = root.create_group('foo') bar = foo.create_array(name='bar', shape=1000000, chunks=100000, dtype='int64') bar[:] = 42 diff --git a/src/zarr/storage/__init__.py b/src/zarr/storage/__init__.py index 00df50214f..f1bd1724af 100644 --- a/src/zarr/storage/__init__.py +++ b/src/zarr/storage/__init__.py @@ -8,7 +8,7 @@ from zarr.storage._fsspec import FsspecStore from zarr.storage._local import LocalStore from zarr.storage._logging import LoggingStore -from zarr.storage._memory import GpuMemoryStore, MemoryStore +from zarr.storage._memory import GpuMemoryStore, ManagedMemoryStore, MemoryStore from zarr.storage._obstore import ObjectStore from zarr.storage._wrapper import WrapperStore from zarr.storage._zip import ZipStore @@ -18,6 +18,7 @@ "GpuMemoryStore", "LocalStore", "LoggingStore", + "ManagedMemoryStore", "MemoryStore", "ObjectStore", "StoreLike", diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..6d57e23c32 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -17,8 +17,8 @@ ) from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError from zarr.storage._local import LocalStore -from zarr.storage._memory import MemoryStore -from zarr.storage._utils import normalize_path +from zarr.storage._memory import ManagedMemoryStore, MemoryStore +from zarr.storage._utils import _dereference_path, normalize_path _has_fsspec = importlib.util.find_spec("fsspec") if _has_fsspec: @@ -30,18 +30,6 @@ from zarr.core.buffer import BufferPrototype -def _dereference_path(root: str, path: str) -> str: - if not isinstance(root, str): - msg = f"{root=} is not a string ({type(root)=})" # type: ignore[unreachable] - raise TypeError(msg) - if not isinstance(path, str): - msg = f"{path=} is not a string ({type(path)=})" # type: ignore[unreachable] - raise TypeError(msg) - root = root.rstrip("/") - path = f"{root}/{path}" if root else path - return path.rstrip("/") - - class StorePath: """ Path-like interface for a Store. @@ -341,8 +329,17 @@ async def make_store( return await LocalStore.open(root=store_like, mode=mode, read_only=_read_only) elif isinstance(store_like, str): + # Check for memory:// URLs first + if store_like.startswith("memory://"): + # Parse the URL to extract name and path + url_without_scheme = store_like[len("memory://") :] + parts = url_without_scheme.split("/", 1) + name = parts[0] if parts[0] else None + path = parts[1] if len(parts) > 1 else "" + # Create or get the store - ManagedMemoryStore handles both cases + return ManagedMemoryStore(name=name, path=path, read_only=_read_only) # Either an FSSpec URI or a local filesystem path - if _is_fsspec_uri(store_like): + elif _is_fsspec_uri(store_like): return FsspecStore.from_url( store_like, storage_options=storage_options, read_only=_read_only ) @@ -418,6 +415,18 @@ async def make_store_path( "'path' was provided but is not used for FSMap store_like objects. Specify the path when creating the FSMap instance instead." ) + elif isinstance(store_like, str) and store_like.startswith("memory://"): + # Handle memory:// URLs specially + # Parse the URL to extract name and path + _read_only = mode == "r" + url_without_scheme = store_like[len("memory://") :] + parts = url_without_scheme.split("/", 1) + name = parts[0] if parts[0] else None + url_path = parts[1] if len(parts) > 1 else "" + # Create or get the store - ManagedMemoryStore handles both cases + memory_store = ManagedMemoryStore(name=name, path=url_path, read_only=_read_only) + return await StorePath.open(memory_store, path=path_normalized, mode=mode) + else: store = await make_store(store_like, mode=mode, storage_options=storage_options) return await StorePath.open(store, path=path_normalized, mode=mode) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..0549b6a300 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -16,7 +16,7 @@ ) from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning -from zarr.storage._common import _dereference_path +from zarr.storage._common import _dereference_path # type: ignore[attr-defined] if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..af4d9a2521 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +import weakref from logging import getLogger from typing import TYPE_CHECKING, Any, Self @@ -7,7 +9,7 @@ from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map -from zarr.storage._utils import _normalize_byte_range_index +from zarr.storage._utils import _dereference_path, _normalize_byte_range_index, normalize_path if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, MutableMapping @@ -475,3 +477,393 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None # Convert to gpu.Buffer gpu_value = value if isinstance(value, gpu.Buffer) else gpu.Buffer.from_buffer(value) await super().set(key, gpu_value, byte_range=byte_range) + + +# ----------------------------------------------------------------------------- +# ManagedMemoryStore and its registry +# ----------------------------------------------------------------------------- +# ManagedMemoryStore owns the lifecycle of its backing dict, enabling proper +# weakref-based tracking. This allows memory:// URLs to be resolved back to +# the store's dict within the same process. + + +class _ManagedStoreDict(dict[str, Buffer]): + """ + A dict subclass that supports weak references. + + Regular dicts don't support weakrefs, but we need to track managed store dicts + in a WeakValueDictionary so they can be garbage collected when no longer + referenced. This subclass adds the necessary __weakref__ slot. + """ + + __slots__ = ("__weakref__",) + + +class _ManagedStoreDictRegistry: + """ + Registry for managed store dicts. + + This registry is the source of truth for managed store dicts. It creates + new dicts, tracks them via weak references, and looks them up by name. + """ + + def __init__(self) -> None: + self._registry: weakref.WeakValueDictionary[str, _ManagedStoreDict] = ( + weakref.WeakValueDictionary() + ) + self._counter = 0 + + def _generate_name(self) -> str: + """Generate a unique name for a store.""" + name = str(self._counter) + self._counter += 1 + return name + + def get_or_create(self, name: str | None = None) -> tuple[_ManagedStoreDict, str]: + """ + Get an existing managed dict by name, or create a new one. + + Parameters + ---------- + name : str | None + The name for the store. If None, a unique name is auto-generated. + If a store with this name already exists, returns the existing store. + Names cannot contain '/' characters. + + Returns + ------- + tuple[_ManagedStoreDict, str] + The store dict and its name. + + Raises + ------ + ValueError + If the name contains '/' characters. + """ + if name is None: + name = self._generate_name() + elif "/" in name: + raise ValueError( + f"Store name cannot contain '/': {name!r}. " + "Use the 'path' parameter to specify a path within the store." + ) + + existing = self._registry.get(name) + if existing is not None: + return existing, name + + store_dict = _ManagedStoreDict() + self._registry[name] = store_dict + return store_dict, name + + def get(self, name: str) -> _ManagedStoreDict | None: + """ + Look up a managed store dict by name. + + Parameters + ---------- + name : str + The name of the store. + + Returns + ------- + _ManagedStoreDict | None + The store dict if found, None otherwise. + """ + return self._registry.get(name) + + def get_from_url(self, url: str) -> tuple[_ManagedStoreDict | None, str, str]: + """ + Look up a managed store dict by its URL. + + Parameters + ---------- + url : str + A URL like "memory://mystore" or "memory://mystore/path/to/node" + + Returns + ------- + tuple[_ManagedStoreDict | None, str, str] + The store dict (if found), the extracted name, and the path. + """ + if not url.startswith("memory://"): + return None, "", "" + + # Parse the store name and path from the URL + # "memory://mystore" -> name="mystore", path="" + # "memory://mystore/path/to/data" -> name="mystore", path="path/to/data" + url_without_scheme = url[len("memory://") :] + parts = url_without_scheme.split("/", 1) + name = parts[0] + path = parts[1] if len(parts) > 1 else "" + + return self._registry.get(name), name, path + + +_managed_store_dict_registry = _ManagedStoreDictRegistry() + + +class ManagedMemoryStore(MemoryStore): + """ + A memory store that owns and manages the lifecycle of its backing dict. + + Unlike ``MemoryStore`` which accepts any ``MutableMapping``, this store + creates and owns its backing dict internally. This enables proper lifecycle + management and allows the store to be looked up by its ``memory://`` URL + within the same process. + + Parameters + ---------- + name : str | None + The name for this store, used in the ``memory://`` URL. If None, a unique + name is auto-generated. If a store with this name already exists, the + new store will share the same backing dict. + path : str + The root path for this store. All keys will be prefixed with this path. + read_only : bool + Whether the store is read-only. + + Attributes + ---------- + name : str + The name of this store. + path : str + The root path of this store. + + Notes + ----- + The backing dict is tracked via weak references and will be garbage collected + when no ``ManagedMemoryStore`` instances reference it. URLs pointing to a + garbage-collected store will fail to resolve. + + See Also + -------- + MemoryStore : A memory store that accepts any MutableMapping. + + Examples + -------- + >>> store = ManagedMemoryStore(name="my-data") + >>> str(store) + 'memory://my-data' + >>> # Later, resolve the URL back to the store's dict + >>> store2 = ManagedMemoryStore.from_url("memory://my-data") + >>> store2._store_dict is store._store_dict + True + >>> # Create a store with a path prefix + >>> store3 = ManagedMemoryStore.from_url("memory://my-data/subdir") + >>> store3.path + 'subdir' + """ + + _store_dict: _ManagedStoreDict + _name: str + path: str + _created_pid: int + + def __init__(self, name: str | None = None, *, path: str = "", read_only: bool = False) -> None: + # Skip MemoryStore.__init__ and call Store.__init__ directly + # because we need to set up _store_dict differently + Store.__init__(self, read_only=read_only) + + # Get or create a managed dict from the registry + self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name) + self.path = normalize_path(path) + self._created_pid = os.getpid() + + def __str__(self) -> str: + return _dereference_path(f"memory://{self._name}", self.path) + + def __repr__(self) -> str: + return f"ManagedMemoryStore('{self}')" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, type(self)) + and self._store_dict is other._store_dict + and self.path == other.path + and self.read_only == other.read_only + ) + + @property + def name(self) -> str: + """The name of this store, used in the memory:// URL.""" + return self._name + + def _check_same_process(self) -> None: + """Raise an error if this store is being used in a different process.""" + current_pid = os.getpid() + if self._created_pid != current_pid: + raise RuntimeError( + f"ManagedMemoryStore '{self._name}' was created in process {self._created_pid} " + f"but is being used in process {current_pid}. " + "ManagedMemoryStore instances cannot be shared across processes because " + "their backing dict is not serialized. Use a persistent store (e.g., " + "LocalStore, ZipStore) for cross-process data sharing." + ) + + @classmethod + def _from_managed_dict( + cls, + managed_dict: _ManagedStoreDict, + name: str, + *, + path: str = "", + read_only: bool = False, + ) -> ManagedMemoryStore: + """Internal: create a store from an existing managed dict.""" + store = object.__new__(cls) + Store.__init__(store, read_only=read_only) + store._store_dict = managed_dict + store._name = name + store.path = normalize_path(path) + store._created_pid = os.getpid() + return store + + def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore: + # docstring inherited + return type(self)._from_managed_dict( + self._store_dict, self._name, path=self.path, read_only=read_only + ) + + @classmethod + def from_url(cls, url: str, *, read_only: bool = False) -> ManagedMemoryStore: + """ + Create a ManagedMemoryStore from a memory:// URL. + + This looks up the backing dict in the process-wide registry and creates + a new store instance that shares the same dict. + + Parameters + ---------- + url : str + A URL like "memory://my-store" or "memory://my-store/path/to/data" + identifying the store and optional path prefix. + read_only : bool + Whether the new store should be read-only. + + Returns + ------- + ManagedMemoryStore + A store sharing the same backing dict as the original. + + Raises + ------ + ValueError + If the URL is not a valid memory:// URL or the store has been + garbage collected. + """ + managed_dict, name, path = _managed_store_dict_registry.get_from_url(url) + if managed_dict is None: + raise ValueError( + f"Memory store not found for URL '{url}'. " + "The store may have been garbage collected or the URL is invalid." + ) + return cls._from_managed_dict(managed_dict, name, path=path, read_only=read_only) + + # Override MemoryStore methods to use path prefix and check process + + async def get( + self, + key: str, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + # docstring inherited + self._check_same_process() + return await super().get( + _dereference_path(self.path, key), prototype=prototype, byte_range=byte_range + ) + + async def get_partial_values( + self, + prototype: BufferPrototype, + key_ranges: Iterable[tuple[str, ByteRequest | None]], + ) -> list[Buffer | None]: + # docstring inherited + self._check_same_process() + key_ranges = [ + (_dereference_path(self.path, key), byte_range) for key, byte_range in key_ranges + ] + return await super().get_partial_values(prototype, key_ranges) + + async def exists(self, key: str) -> bool: + # docstring inherited + self._check_same_process() + return await super().exists(_dereference_path(self.path, key)) + + async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + # docstring inherited + self._check_same_process() + return await super().set(_dereference_path(self.path, key), value, byte_range=byte_range) + + async def set_if_not_exists(self, key: str, value: Buffer) -> None: + # docstring inherited + self._check_same_process() + return await super().set_if_not_exists(_dereference_path(self.path, key), value) + + async def delete(self, key: str) -> None: + # docstring inherited + self._check_same_process() + return await super().delete(_dereference_path(self.path, key)) + + async def list(self) -> AsyncIterator[str]: + # docstring inherited + self._check_same_process() + prefix = self.path + "/" if self.path else "" + async for key in super().list(): + if key.startswith(prefix): + yield key.removeprefix(prefix) + + async def list_prefix(self, prefix: str) -> AsyncIterator[str]: + # docstring inherited + self._check_same_process() + # Don't use _dereference_path here because it strips trailing slashes, + # which would break prefix matching (e.g., "fo/" vs "foo/") + full_prefix = f"{self.path}/{prefix}" if self.path else prefix + path_prefix = self.path + "/" if self.path else "" + async for key in super().list_prefix(full_prefix): + yield key.removeprefix(path_prefix) + + async def list_dir(self, prefix: str) -> AsyncIterator[str]: + # docstring inherited + self._check_same_process() + full_prefix = _dereference_path(self.path, prefix) + async for key in super().list_dir(full_prefix): + yield key + + def __reduce__( + self, + ) -> tuple[type[ManagedMemoryStore], tuple[str | None], dict[str, Any]]: + """ + Support pickling of ManagedMemoryStore. + + On unpickle, the store will reconnect to an existing store with the same + name if one exists in the registry, or create a new empty store otherwise. + + Note that the backing dict data is NOT serialized - only the store's + identity (name, path, read_only) is preserved. If the original store has + been garbage collected, the unpickled store will have an empty dict. + + The original process ID is preserved so that cross-process usage can be + detected and will raise an error. + """ + return ( + self.__class__, + (self._name,), + { + "path": self.path, + "read_only": self.read_only, + "created_pid": self._created_pid, + }, + ) + + def __setstate__(self, state: dict[str, Any]) -> None: + """Restore state after unpickling.""" + # The __reduce__ method returns (cls, (name,), state) + # Python calls cls(name) then __setstate__(state) + # But __init__ already set up _store_dict and _name from the registry + # We just need to restore path, read_only, and the original process ID + self.path = normalize_path(state.get("path", "")) + self._read_only = state.get("read_only", False) + # Preserve the original process ID to detect cross-process usage + self._created_pid = state.get("created_pid", os.getpid()) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 10ac395b36..78a491de16 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -70,6 +70,33 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> return (start, stop) +def _dereference_path(root: str, path: str) -> str: + """ + Combine a root path with a relative path. + + Parameters + ---------- + root : str + The root path. + path : str + The path relative to root. + + Returns + ------- + str + The combined path with trailing slashes removed. + """ + if not isinstance(root, str): + msg = f"{root=} is not a string ({type(root)=})" # type: ignore[unreachable] + raise TypeError(msg) + if not isinstance(path, str): + msg = f"{path=} is not a string ({type(path)=})" # type: ignore[unreachable] + raise TypeError(msg) + root = root.rstrip("/") + path = f"{root}/{path}" if root else path + return path.rstrip("/") + + def _join_paths(paths: Iterable[str]) -> str: """ Filter out instances of '' and join the remaining strings with '/'. diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 330f220b56..f539f49b0f 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -21,7 +21,7 @@ from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.sync import sync from zarr.storage import MemoryStore, StoreLike -from zarr.storage._common import _dereference_path +from zarr.storage._common import _dereference_path # type: ignore[attr-defined] from zarr.storage._utils import normalize_path from zarr.types import AnyArray diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 03c8b24271..0fc8bfd341 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -12,7 +12,7 @@ from zarr.core.buffer import Buffer, cpu, gpu from zarr.core.sync import sync from zarr.errors import ZarrUserWarning -from zarr.storage import GpuMemoryStore, MemoryStore +from zarr.storage import GpuMemoryStore, ManagedMemoryStore, MemoryStore from zarr.testing.store import StoreTests from zarr.testing.utils import gpu_test @@ -181,3 +181,353 @@ def test_from_dict(self) -> None: result = GpuMemoryStore.from_dict(d) for v in result._store_dict.values(): assert type(v) is gpu.Buffer + + +class TestManagedMemoryStore(StoreTests[ManagedMemoryStore, cpu.Buffer]): + store_cls = ManagedMemoryStore + buffer_cls = cpu.Buffer + + async def set(self, store: ManagedMemoryStore, key: str, value: Buffer) -> None: + store._store_dict[key] = value + + async def get(self, store: ManagedMemoryStore, key: str) -> Buffer: + return store._store_dict[key] + + @pytest.fixture + def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, Any]: + # Use a unique name per test to avoid sharing state between tests + # but ensure the name is deterministic for equality tests + # Replace '/' with '-' since store names cannot contain '/' + sanitized_name = request.node.name.replace("/", "-") + return {"name": f"test-{sanitized_name}"} + + @pytest.fixture + async def store(self, store_kwargs: dict[str, Any]) -> ManagedMemoryStore: + return self.store_cls(**store_kwargs) + + def test_store_repr(self, store: ManagedMemoryStore) -> None: + assert str(store) == f"memory://{store.name}" + + async def test_serializable_store(self, store: ManagedMemoryStore) -> None: + """ + Test pickling semantics for ManagedMemoryStore. + + When pickled and unpickled within the same process (where the original + store still exists in the registry), the unpickled store reconnects to + the same backing dict. + """ + import pickle + + # Add some data to the store + await store.set("test-key", self.buffer_cls.from_bytes(b"test-value")) + + # Pickle and unpickle the store + pickled = pickle.dumps(store) + store2 = pickle.loads(pickled) + + # The unpickled store should reconnect to the same backing dict + assert store2._store_dict is store._store_dict + assert store2.name == store.name + assert store2.path == store.path + assert store2.read_only == store.read_only + + # The data should be accessible + result = await store2.get("test-key") + assert result is not None + assert result.to_bytes() == b"test-value" + + async def test_pickle_with_path(self) -> None: + """Test that path is preserved through pickle round-trip.""" + import pickle + + store = ManagedMemoryStore(name="pickle-path-test", path="some/path") + await store.set("key", self.buffer_cls.from_bytes(b"value")) + + pickled = pickle.dumps(store) + store2 = pickle.loads(pickled) + + assert store2.path == "some/path" + assert store2._store_dict is store._store_dict + + # Check that operations use the path correctly + result = await store2.get("key") + assert result is not None + assert result.to_bytes() == b"value" + + def test_pickle_after_gc(self) -> None: + """ + Test that unpickling after the original store is garbage collected + creates a new empty store with the same name (in the same process). + """ + import gc + import pickle + + # Create a store with a unique name and pickle it + store = ManagedMemoryStore(name="gc-pickle-test") + store._store_dict["key"] = self.buffer_cls.from_bytes(b"value") + pickled = pickle.dumps(store) + + # Delete the store and garbage collect + del store + gc.collect() + + # Unpickling should create a new store with an empty dict + store2 = pickle.loads(pickled) + assert store2.name == "gc-pickle-test" + # The dict is empty because the original was garbage collected + assert len(store2._store_dict) == 0 + + async def test_cross_process_detection(self) -> None: + """ + Test that using a ManagedMemoryStore in a different process raises an error. + + This prevents silent data loss when a store is pickled and unpickled + in a different process (e.g., with multiprocessing). + """ + import pickle + + store = ManagedMemoryStore(name="cross-process-test") + await store.set("key", self.buffer_cls.from_bytes(b"value")) + + # Simulate unpickling in a different process by manipulating _created_pid + pickled = pickle.dumps(store) + store2 = pickle.loads(pickled) + + # Manually change the created_pid to simulate a different process + store2._created_pid = store2._created_pid + 1 + + # All operations should raise RuntimeError + with pytest.raises(RuntimeError, match="was created in process"): + await store2.get("key") + + with pytest.raises(RuntimeError, match="was created in process"): + await store2.set("key", self.buffer_cls.from_bytes(b"value")) + + with pytest.raises(RuntimeError, match="was created in process"): + await store2.exists("key") + + with pytest.raises(RuntimeError, match="was created in process"): + await store2.delete("key") + + with pytest.raises(RuntimeError, match="was created in process"): + [k async for k in store2.list()] + + with pytest.raises(RuntimeError, match="was created in process"): + [k async for k in store2.list_prefix("")] + + with pytest.raises(RuntimeError, match="was created in process"): + [k async for k in store2.list_dir("")] + + def test_store_supports_writes(self, store: ManagedMemoryStore) -> None: + assert store.supports_writes + + def test_store_supports_listing(self, store: ManagedMemoryStore) -> None: + assert store.supports_listing + + async def test_list_prefix(self, store: MemoryStore) -> None: + assert True + + @pytest.mark.parametrize("dtype", ["uint8", "float32", "int64"]) + @pytest.mark.parametrize("zarr_format", [2, 3]) + async def test_deterministic_size( + self, store: MemoryStore, dtype: npt.DTypeLike, zarr_format: ZarrFormat + ) -> None: + a = zarr.empty( + store=store, + shape=(3,), + chunks=(1000,), + dtype=dtype, + zarr_format=zarr_format, + overwrite=True, + ) + a[...] = 1 + a.resize((1000,)) + + np.testing.assert_array_equal(a[:3], 1) + np.testing.assert_array_equal(a[3:], 0) + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: ManagedMemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store._get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: ManagedMemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store._get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: ManagedMemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store._get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: ManagedMemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store._get_json_sync(key, prototype=buffer_cls) + assert result == data + + def test_from_url(self, store: ManagedMemoryStore) -> None: + """Test that from_url creates a store sharing the same dict.""" + url = str(store) + store2 = ManagedMemoryStore.from_url(url) + assert store2._store_dict is store._store_dict + + def test_from_url_with_path(self, store: ManagedMemoryStore) -> None: + """Test that from_url extracts path component from URL.""" + url = str(store) + "/some/path" + store2 = ManagedMemoryStore.from_url(url) + assert store2._store_dict is store._store_dict + assert store2.path == "some/path" + assert str(store2) == url + + def test_from_url_invalid(self) -> None: + """Test that from_url raises ValueError for non-existent store.""" + with pytest.raises(ValueError, match="Memory store not found"): + ManagedMemoryStore.from_url("memory://nonexistent-store") + + def test_from_url_not_memory_scheme(self) -> None: + """Test that from_url raises ValueError for non-memory URLs.""" + with pytest.raises(ValueError, match="Memory store not found"): + ManagedMemoryStore.from_url("file:///tmp/test") + + def test_named_store(self) -> None: + """Test that stores can be created with explicit names.""" + store = ManagedMemoryStore(name="my-test-store") + assert store.name == "my-test-store" + assert str(store) == "memory://my-test-store" + + def test_named_store_shares_dict(self) -> None: + """Test that creating a store with the same name shares the dict.""" + store1 = ManagedMemoryStore(name="shared-store") + store2 = ManagedMemoryStore(name="shared-store") + assert store1._store_dict is store2._store_dict + assert store1.name == store2.name + + def test_auto_generated_name(self) -> None: + """Test that stores get auto-generated names when none provided.""" + store = ManagedMemoryStore() + assert store.name is not None + assert str(store) == f"memory://{store.name}" + + def test_with_read_only_shares_dict(self, store: ManagedMemoryStore) -> None: + """Test that with_read_only creates a store sharing the same dict.""" + store2 = store.with_read_only(True) + assert store2._store_dict is store._store_dict + assert store2.read_only is True + assert store.read_only is False + + def test_with_read_only_preserves_path(self) -> None: + """Test that with_read_only preserves the path.""" + store = ManagedMemoryStore(name="path-test", path="some/path") + store2 = store.with_read_only(True) + assert store2.path == "some/path" + assert store2._store_dict is store._store_dict + + async def test_path_prefix_operations(self) -> None: + """Test that store operations use the path prefix correctly.""" + store = ManagedMemoryStore(name="prefix-test") + store_with_path = ManagedMemoryStore.from_url("memory://prefix-test/subdir") + + # Write via store_with_path + await store_with_path.set("key", self.buffer_cls.from_bytes(b"value")) + + # The key should be stored with the prefix in the underlying dict + assert "subdir/key" in store._store_dict + assert "key" not in store._store_dict + + # Read via store_with_path should work + result = await store_with_path.get("key") + assert result is not None + assert result.to_bytes() == b"value" + + # Read via store without path should use full key + result2 = await store.get("subdir/key") + assert result2 is not None + assert result2.to_bytes() == b"value" + + async def test_path_list_operations(self) -> None: + """Test that list operations filter by path prefix.""" + store = ManagedMemoryStore(name="list-test") + + # Set up some keys at different paths + await store.set("a/key1", self.buffer_cls.from_bytes(b"v1")) + await store.set("a/key2", self.buffer_cls.from_bytes(b"v2")) + await store.set("b/key3", self.buffer_cls.from_bytes(b"v3")) + + # Create a store with path "a" + store_a = ManagedMemoryStore.from_url("memory://list-test/a") + + # list() should only return keys under "a", without the "a/" prefix + keys = [k async for k in store_a.list()] + assert sorted(keys) == ["key1", "key2"] + + async def test_path_exists(self) -> None: + """Test that exists() uses the path prefix.""" + store = ManagedMemoryStore(name="exists-test") + await store.set("prefix/key", self.buffer_cls.from_bytes(b"value")) + + store_with_path = ManagedMemoryStore.from_url("memory://exists-test/prefix") + assert await store_with_path.exists("key") + assert not await store_with_path.exists("prefix/key") + + def test_path_normalization(self) -> None: + """Test that paths are normalized.""" + store1 = ManagedMemoryStore(name="norm-test", path="a/b/") + store2 = ManagedMemoryStore(name="norm-test", path="/a/b") + store3 = ManagedMemoryStore(name="norm-test", path="a//b") + assert store1.path == "a/b" + assert store2.path == "a/b" + assert store3.path == "a/b" + + def test_name_cannot_contain_slash(self) -> None: + """Test that store names cannot contain '/'.""" + with pytest.raises(ValueError, match="cannot contain '/'"): + ManagedMemoryStore(name="foo/bar") + + def test_garbage_collection(self) -> None: + """Test that the dict is garbage collected when no stores reference it.""" + import gc + + store = ManagedMemoryStore() + url = str(store) + + # URL should resolve while store exists + store2 = ManagedMemoryStore.from_url(url) + assert store2._store_dict is store._store_dict + + # Delete both stores + del store + del store2 + gc.collect() + + # URL should no longer resolve + with pytest.raises(ValueError, match="garbage collected"): + ManagedMemoryStore.from_url(url)