From 021bd440c3ee2fae13aaa255d722bb582d22b6a0 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 17:29:56 +0100 Subject: [PATCH 01/10] add store routines for getting bytes and json --- src/zarr/abc/store.py | 215 ++++++++++++++++++++++++++++- src/zarr/storage/_common.py | 107 +++++++++++++++ src/zarr/storage/_local.py | 232 +++++++++++++++++++++++++++++++- src/zarr/storage/_memory.py | 232 +++++++++++++++++++++++++++++++- src/zarr/testing/store.py | 41 +++++- tests/test_store/test_local.py | 82 +++++++++++ tests/test_store/test_memory.py | 70 ++++++++++ 7 files changed, 974 insertions(+), 5 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4b3edf78d1..7d0589c836 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,11 +1,14 @@ from __future__ import annotations +import asyncio +import json from abc import ABC, abstractmethod -from asyncio import gather from dataclasses import dataclass from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.sync import sync + if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable from types import TracebackType @@ -206,6 +209,214 @@ async def get( """ ... + async def get_bytes_async( + self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> bytes: + """ + Retrieve raw bytes from the store asynchronously. + + This is a convenience method that wraps ``get()`` and converts the result + to bytes. Use this when you need the raw byte content of a stored value. + + Parameters + ---------- + key : str + The key identifying the data to retrieve. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + get : Lower-level method that returns a Buffer object. + get_bytes : Synchronous version of this method. + get_json_async : Asynchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = await store.get_bytes_async("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + buffer = await self.get(key, prototype, byte_range) + if buffer is None: + raise FileNotFoundError(key) + return buffer.to_bytes() + + def get_bytes( + self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> bytes: + """ + Retrieve raw bytes from the store synchronously. + + This is a synchronous wrapper around ``get_bytes_async()``. It should only + be called from non-async code. For async contexts, use ``get_bytes_async()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes_async()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_bytes_async : Asynchronous version of this method. + get_json : Synchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = MemoryStore() + >>> store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = store.get_bytes("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + + return sync(self.get_bytes_async(key, prototype=prototype, byte_range=byte_range)) + + async def get_json_async( + self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Any: + """ + Retrieve and parse JSON data from the store asynchronously. + + This is a convenience method that retrieves bytes from the store and + parses them as JSON. Commonly used for reading Zarr metadata files + like ``zarr.json``. + + Parameters + ---------- + key : str + The key identifying the JSON data to retrieve. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + get_bytes_async : Method for retrieving raw bytes without parsing. + get_json : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = await store.get_json_async("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return json.loads( + await self.get_bytes_async(key, prototype=prototype, byte_range=byte_range) + ) + + def get_json( + self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Any: + """ + Retrieve and parse JSON data from the store synchronously. + + This is a synchronous wrapper around ``get_json_async()``. It should only + be called from non-async code. For async contexts, use ``get_json_async()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json_async()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_json_async : Asynchronous version of this method. + get_bytes : Synchronous method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return sync(self.get_json_async(key, prototype=prototype, byte_range=byte_range)) + @abstractmethod async def get_partial_values( self, @@ -278,7 +489,7 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ Insert multiple (key, value) pairs into storage. """ - await gather(*starmap(self.set, values)) + await asyncio.gather(*starmap(self.set, values)) @property def supports_consolidated_metadata(self) -> bool: diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index d762097cc3..b8dcaff3be 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,6 +228,113 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) + async def get_bytes_async( + self, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the store path asynchronously. + + This is a convenience method that wraps ``get()`` and converts the result + to bytes. The ``prototype`` parameter is optional and defaults to the + standard buffer prototype. + + Parameters + ---------- + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at this path. + + Raises + ------ + FileNotFoundError + If the path does not exist in the store. + + See Also + -------- + get : Lower-level method that returns a Buffer object. + get_json_async : Asynchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> path = StorePath(store, "data") + >>> await path.set(Buffer.from_bytes(b"hello world")) + >>> data = await path.get_bytes_async() + >>> print(data) + b'hello world' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await self.store.get_bytes_async( + self.path, prototype=prototype, byte_range=byte_range + ) + + async def get_json_async( + self, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the store path asynchronously. + + This is a convenience method that retrieves bytes from the store and + parses them as JSON. The ``prototype`` parameter is optional and defaults + to the standard buffer prototype. + + Parameters + ---------- + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the path does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + get_bytes_async : Method for retrieving raw bytes without parsing. + get : Lower-level method that returns a Buffer object. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> path = StorePath(store, "zarr.json") + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await path.set(Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = await path.get_json_async() + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await self.store.get_json_async( + self.path, prototype=prototype, byte_range=byte_range + ) + def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f64da71bb4..13c86a2f22 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -8,7 +8,7 @@ import sys import uuid from pathlib import Path -from typing import TYPE_CHECKING, BinaryIO, Literal, Self +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( ByteRequest, @@ -306,6 +306,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: except (FileNotFoundError, NotADirectoryError): pass + async def get_bytes_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes_async : Base implementation with full documentation. + get_bytes : Synchronous version of this method. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_bytes_async("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + + def get_bytes( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes_async()`` instead. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_async : Asynchronous version of this method. + + Examples + -------- + >>> store = LocalStore("data") + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + async def get_json_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json_async : Base implementation with full documentation. + get_json : Synchronous version of this method. + get_bytes_async : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_json_async("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + + def get_json( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json_async()`` instead. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_async : Asynchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = LocalStore("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_json(key, prototype=prototype, byte_range=byte_range) + async def move(self, dest_root: Path | str) -> None: """ Move the store to another path. The old root directory is deleted. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 904be922d7..b56771f62a 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,7 +1,7 @@ from __future__ import annotations from logging import getLogger -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING, Any, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu @@ -175,6 +175,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: for key in keys_unique: yield key + async def get_bytes_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes_async : Base implementation with full documentation. + get_bytes : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_bytes_async("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + + def get_bytes( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes_async()`` instead. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_async : Asynchronous version of this method. + + Examples + -------- + >>> store = MemoryStore() + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + async def get_json_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json_async : Base implementation with full documentation. + get_json : Synchronous version of this method. + get_bytes_async : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_json_async("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + + def get_json( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json_async()`` instead. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_async : Asynchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_json(key, prototype=prototype, byte_range=byte_range) + class GpuMemoryStore(MemoryStore): """ diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ad3b80da41..bee28639a2 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import pickle from abc import abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar @@ -23,7 +24,7 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, default_buffer_prototype -from zarr.core.sync import _collect_aiterator +from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -526,6 +527,44 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new + async def test_get_bytes_async(self, store: S) -> None: + """ + Test that the get_bytes_async method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + assert await store.get_bytes_async(key, prototype=default_buffer_prototype()) == data + + def test_get_bytes_sync(self, store: S) -> None: + """ + Test that the get_bytes method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + assert store.get_bytes(key, prototype=default_buffer_prototype()) == data + + async def test_get_json_async(self, store: S) -> None: + """ + Test that the get_bytes_async method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) + assert await store.get_json_async(key, prototype=default_buffer_prototype()) == data + + def test_get_json_sync(self, store: S) -> None: + """ + Test that the get_json method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) + assert store.get_json(key, prototype=default_buffer_prototype()) == data + class LatencyStore(WrapperStore[Store]): """ diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 6756bc83d9..35d48e3f95 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -150,3 +150,85 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files + + +async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_bytes_async works with prototype=None.""" + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + + store = await LocalStore.open(root=tmp_path) + data = b"hello world" + key = "test_key" + await store.set(key, cpu.Buffer.from_bytes(data)) + + # Test with None (default) + result_none = await store.get_bytes_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + +def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_bytes works with prototype=None.""" + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + store = sync(LocalStore.open(root=tmp_path)) + data = b"hello world" + key = "test_key" + sync(store.set(key, cpu.Buffer.from_bytes(data))) + + # Test with None (default) + result_none = store.get_bytes(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + assert result_proto == data + + +async def test_get_json_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_json_async works with prototype=None.""" + import json + + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + + store = await LocalStore.open(root=tmp_path) + data = {"foo": "bar", "number": 42} + key = "test.json" + await store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode())) + + # Test with None (default) + result_none = await store.get_json_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + +def test_get_json_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_json works with prototype=None.""" + import json + + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + store = sync(LocalStore.open(root=tmp_path)) + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode()))) + + # Test with None (default) + result_none = store.get_json(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_json(key, prototype=default_buffer_prototype()) + assert result_proto == data diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 29fa9b2964..b56d9933d4 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -76,6 +76,76 @@ async def test_deterministic_size( np.testing.assert_array_equal(a[:3], 1) np.testing.assert_array_equal(a[3:], 0) + async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_bytes_async works with prototype=None.""" + from zarr.core.buffer.core import default_buffer_prototype + + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + # Test with None (default) + result_none = await store.get_bytes_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_bytes works with prototype=None.""" + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + # Test with None (default) + result_none = store.get_bytes(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + assert result_proto == data + + async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_json_async works with prototype=None.""" + import json + + from zarr.core.buffer.core import default_buffer_prototype + + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + # Test with None (default) + result_none = await store.get_json_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_json works with prototype=None.""" + import json + + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + # Test with None (default) + result_none = store.get_json(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_json(key, prototype=default_buffer_prototype()) + assert result_proto == data + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") From 7d26b8ee4d33c6c784e42784e1acf37c5836dd8e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 18:02:29 +0100 Subject: [PATCH 02/10] check for FileNotFoundError when a key is missing --- src/zarr/testing/store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index bee28639a2..30ff376fb0 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -535,6 +535,8 @@ async def test_get_bytes_async(self, store: S) -> None: key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data)) assert await store.get_bytes_async(key, prototype=default_buffer_prototype()) == data + with pytest.raises(FileNotFoundError): + await store.get_bytes_async("nonexistent_key", prototype=default_buffer_prototype()) def test_get_bytes_sync(self, store: S) -> None: """ From 971c3e4fb6dd6c895e49c2763be5c1b9164b9114 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 18:02:39 +0100 Subject: [PATCH 03/10] remove storepath methods --- src/zarr/storage/_common.py | 107 ------------------------------------ 1 file changed, 107 deletions(-) diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index b8dcaff3be..d762097cc3 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,113 +228,6 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) - async def get_bytes_async( - self, - prototype: BufferPrototype | None = None, - byte_range: ByteRequest | None = None, - ) -> bytes: - """ - Retrieve raw bytes from the store path asynchronously. - - This is a convenience method that wraps ``get()`` and converts the result - to bytes. The ``prototype`` parameter is optional and defaults to the - standard buffer prototype. - - Parameters - ---------- - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. - byte_range : ByteRequest, optional - If specified, only retrieve a portion of the stored data. - Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. - - Returns - ------- - bytes - The raw bytes stored at this path. - - Raises - ------ - FileNotFoundError - If the path does not exist in the store. - - See Also - -------- - get : Lower-level method that returns a Buffer object. - get_json_async : Asynchronous method for retrieving and parsing JSON data. - - Examples - -------- - >>> store = await MemoryStore.open() - >>> path = StorePath(store, "data") - >>> await path.set(Buffer.from_bytes(b"hello world")) - >>> data = await path.get_bytes_async() - >>> print(data) - b'hello world' - """ - if prototype is None: - prototype = default_buffer_prototype() - return await self.store.get_bytes_async( - self.path, prototype=prototype, byte_range=byte_range - ) - - async def get_json_async( - self, - prototype: BufferPrototype | None = None, - byte_range: ByteRequest | None = None, - ) -> Any: - """ - Retrieve and parse JSON data from the store path asynchronously. - - This is a convenience method that retrieves bytes from the store and - parses them as JSON. The ``prototype`` parameter is optional and defaults - to the standard buffer prototype. - - Parameters - ---------- - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. - byte_range : ByteRequest, optional - If specified, only retrieve a portion of the stored data. - Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. - Note: Using byte ranges with JSON may result in invalid JSON. - - Returns - ------- - Any - The parsed JSON data. This follows the behavior of ``json.loads()`` and - can be any JSON-serializable type: dict, list, str, int, float, bool, or None. - - Raises - ------ - FileNotFoundError - If the path does not exist in the store. - json.JSONDecodeError - If the stored data is not valid JSON. - - See Also - -------- - get_bytes_async : Method for retrieving raw bytes without parsing. - get : Lower-level method that returns a Buffer object. - - Examples - -------- - >>> store = await MemoryStore.open() - >>> path = StorePath(store, "zarr.json") - >>> metadata = {"zarr_format": 3, "node_type": "array"} - >>> await path.set(Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = await path.get_json_async() - >>> print(data) - {'zarr_format': 3, 'node_type': 'array'} - """ - if prototype is None: - prototype = default_buffer_prototype() - return await self.store.get_json_async( - self.path, prototype=prototype, byte_range=byte_range - ) - def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) From d70a5e5277ca7a225d9fb0ac22945b5e1fcb8cc3 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 18:24:39 +0100 Subject: [PATCH 04/10] changelog --- changes/3638.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3638.feature.md diff --git a/changes/3638.feature.md b/changes/3638.feature.md new file mode 100644 index 0000000000..ad2276fd51 --- /dev/null +++ b/changes/3638.feature.md @@ -0,0 +1 @@ +Add methods for reading stored objects as bytes and JSON-decoded bytes to store classes. \ No newline at end of file From a21305887b7287d957811afb0a8ec8f894d8cfc7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 10:37:23 +0100 Subject: [PATCH 05/10] rename methods --- src/zarr/abc/store.py | 45 +++++++++++++------------- src/zarr/storage/_local.py | 26 +++++++-------- src/zarr/storage/_memory.py | 26 +++++++-------- src/zarr/testing/store.py | 12 +++---- tests/test_store/test_local.py | 56 ++++++++++++++------------------- tests/test_store/test_memory.py | 18 +++++------ 6 files changed, 85 insertions(+), 98 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 7d0589c836..e685e4b3b0 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -209,7 +209,7 @@ async def get( """ ... - async def get_bytes_async( + async def get_bytes( self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> bytes: """ @@ -242,7 +242,7 @@ async def get_bytes_async( -------- get : Lower-level method that returns a Buffer object. get_bytes : Synchronous version of this method. - get_json_async : Asynchronous method for retrieving and parsing JSON data. + get_json : Asynchronous method for retrieving and parsing JSON data. Examples -------- @@ -257,7 +257,7 @@ async def get_bytes_async( raise FileNotFoundError(key) return buffer.to_bytes() - def get_bytes( + def get_bytes_sync( self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> bytes: """ @@ -289,7 +289,7 @@ def get_bytes( Warnings -------- - Do not call this method from async functions. Use ``get_bytes_async()`` instead + Do not call this method from async functions. Use ``get_bytes()`` instead to avoid blocking the event loop. See Also @@ -300,23 +300,22 @@ def get_bytes( Examples -------- >>> store = MemoryStore() - >>> store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = store.get_bytes("data", prototype=default_buffer_prototype()) + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = store.get_bytes_sync("data", prototype=default_buffer_prototype()) >>> print(data) b'hello world' """ - return sync(self.get_bytes_async(key, prototype=prototype, byte_range=byte_range)) + return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range)) - async def get_json_async( + async def get_json( self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Any: """ Retrieve and parse JSON data from the store asynchronously. This is a convenience method that retrieves bytes from the store and - parses them as JSON. Commonly used for reading Zarr metadata files - like ``zarr.json``. + parses them as JSON. Parameters ---------- @@ -344,31 +343,29 @@ async def get_json_async( See Also -------- - get_bytes_async : Method for retrieving raw bytes without parsing. - get_json : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes. + get_json_sync : Synchronous version of this method. Examples -------- >>> store = await MemoryStore.open() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = await store.get_json_async("zarr.json", prototype=default_buffer_prototype()) + >>> data = await store.get_json("zarr.json", prototype=default_buffer_prototype()) >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ - return json.loads( - await self.get_bytes_async(key, prototype=prototype, byte_range=byte_range) - ) + return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range)) - def get_json( + def get_json_sync( self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Any: """ Retrieve and parse JSON data from the store synchronously. - This is a synchronous wrapper around ``get_json_async()``. It should only - be called from non-async code. For async contexts, use ``get_json_async()`` + This is a synchronous wrapper around ``get_json()``. It should only + be called from non-async code. For async contexts, use ``get_json()`` instead. Parameters @@ -397,25 +394,25 @@ def get_json( Warnings -------- - Do not call this method from async functions. Use ``get_json_async()`` instead + Do not call this method from async functions. Use ``get_json()`` instead to avoid blocking the event loop. See Also -------- - get_json_async : Asynchronous version of this method. - get_bytes : Synchronous method for retrieving raw bytes without parsing. + get_json : Asynchronous version of this method. + get_bytes_sync : Synchronous method for retrieving raw bytes without parsing. Examples -------- >>> store = MemoryStore() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> data = store.get_json_sync("zarr.json", prototype=default_buffer_prototype()) >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ - return sync(self.get_json_async(key, prototype=prototype, byte_range=byte_range)) + return sync(self.get_json(key, prototype=prototype, byte_range=byte_range)) @abstractmethod async def get_partial_values( diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 13c86a2f22..08681f2630 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -306,7 +306,7 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: except (FileNotFoundError, NotADirectoryError): pass - async def get_bytes_async( + async def get_bytes( self, key: str = "", *, @@ -356,9 +356,9 @@ async def get_bytes_async( """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) - def get_bytes( + def get_bytes_sync( self, key: str = "", *, @@ -412,9 +412,9 @@ def get_bytes( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) - async def get_json_async( + async def get_json( self, key: str = "", *, @@ -425,7 +425,7 @@ async def get_json_async( Retrieve and parse JSON data from the local store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` for full documentation. Parameters @@ -454,7 +454,7 @@ async def get_json_async( See Also -------- - Store.get_json_async : Base implementation with full documentation. + Store.get_json : Base implementation with full documentation. get_json : Synchronous version of this method. get_bytes_async : Method for retrieving raw bytes without parsing. @@ -465,15 +465,15 @@ async def get_json_async( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for LocalStore - >>> data = await store.get_json_async("zarr.json") + >>> data = await store.get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_json(key, prototype=prototype, byte_range=byte_range) - def get_json( + def get_json_sync( self, key: str = "", *, @@ -513,12 +513,12 @@ def get_json( Warnings -------- - Do not call this method from async functions. Use ``get_json_async()`` instead. + Do not call this method from async functions. Use ``get_json()`` instead. See Also -------- Store.get_json : Base implementation with full documentation. - get_json_async : Asynchronous version of this method. + get_json : Asynchronous version of this method. get_bytes : Method for retrieving raw bytes without parsing. Examples @@ -534,7 +534,7 @@ def get_json( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_json(key, prototype=prototype, byte_range=byte_range) + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) async def move(self, dest_root: Path | str) -> None: """ diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index b56771f62a..5a2593eb25 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -175,7 +175,7 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: for key in keys_unique: yield key - async def get_bytes_async( + async def get_bytes( self, key: str = "", *, @@ -225,9 +225,9 @@ async def get_bytes_async( """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) - def get_bytes( + def get_bytes_sync( self, key: str = "", *, @@ -281,9 +281,9 @@ def get_bytes( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) - async def get_json_async( + async def get_json( self, key: str = "", *, @@ -294,7 +294,7 @@ async def get_json_async( Retrieve and parse JSON data from the memory store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` for full documentation. Parameters @@ -323,7 +323,7 @@ async def get_json_async( See Also -------- - Store.get_json_async : Base implementation with full documentation. + Store.get_json : Base implementation with full documentation. get_json : Synchronous version of this method. get_bytes_async : Method for retrieving raw bytes without parsing. @@ -334,15 +334,15 @@ async def get_json_async( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for MemoryStore - >>> data = await store.get_json_async("zarr.json") + >>> data = await store.get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_json(key, prototype=prototype, byte_range=byte_range) - def get_json( + def get_json_sync( self, key: str = "", *, @@ -382,12 +382,12 @@ def get_json( Warnings -------- - Do not call this method from async functions. Use ``get_json_async()`` instead. + Do not call this method from async functions. Use ``get_json()`` instead. See Also -------- Store.get_json : Base implementation with full documentation. - get_json_async : Asynchronous version of this method. + get_json : Asynchronous version of this method. get_bytes : Method for retrieving raw bytes without parsing. Examples @@ -403,7 +403,7 @@ def get_json( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_json(key, prototype=prototype, byte_range=byte_range) + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) class GpuMemoryStore(MemoryStore): diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 30ff376fb0..f0cb6dd48f 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -534,9 +534,9 @@ async def test_get_bytes_async(self, store: S) -> None: data = b"hello world" key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data)) - assert await store.get_bytes_async(key, prototype=default_buffer_prototype()) == data + assert await store.get_bytes(key, prototype=default_buffer_prototype()) == data with pytest.raises(FileNotFoundError): - await store.get_bytes_async("nonexistent_key", prototype=default_buffer_prototype()) + await store.get_bytes("nonexistent_key", prototype=default_buffer_prototype()) def test_get_bytes_sync(self, store: S) -> None: """ @@ -545,9 +545,9 @@ def test_get_bytes_sync(self, store: S) -> None: data = b"hello world" key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) - assert store.get_bytes(key, prototype=default_buffer_prototype()) == data + assert store.get_bytes_sync(key, prototype=default_buffer_prototype()) == data - async def test_get_json_async(self, store: S) -> None: + async def test_get_json(self, store: S) -> None: """ Test that the get_bytes_async method reads json. """ @@ -555,7 +555,7 @@ async def test_get_json_async(self, store: S) -> None: data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) - assert await store.get_json_async(key, prototype=default_buffer_prototype()) == data + assert await store.get_json(key, prototype=default_buffer_prototype()) == data def test_get_json_sync(self, store: S) -> None: """ @@ -565,7 +565,7 @@ def test_get_json_sync(self, store: S) -> None: data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) - assert store.get_json(key, prototype=default_buffer_prototype()) == data + assert store.get_json_sync(key, prototype=default_buffer_prototype()) == data class LatencyStore(WrapperStore[Store]): diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 35d48e3f95..d6a97110ad 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import pathlib import re @@ -9,6 +10,8 @@ import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu +from zarr.core.buffer.core import BufferPrototype, default_buffer_prototype +from zarr.core.sync import sync from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests @@ -153,9 +156,8 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes_async works with prototype=None.""" + """Test that get_bytes works with prototype=None.""" from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype store = await LocalStore.open(root=tmp_path) data = b"hello world" @@ -163,18 +165,17 @@ async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: await store.set(key, cpu.Buffer.from_bytes(data)) # Test with None (default) - result_none = await store.get_bytes_async(key) + result_none = await store.get_bytes(key) assert result_none == data # Test with explicit prototype - result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) assert result_proto == data def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes works with prototype=None.""" + """Test that get_bytes_sync works with prototype=None.""" from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype from zarr.core.sync import sync store = sync(LocalStore.open(root=tmp_path)) @@ -183,20 +184,19 @@ def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: sync(store.set(key, cpu.Buffer.from_bytes(data))) # Test with None (default) - result_none = store.get_bytes(key) + result_none = store.get_bytes_sync(key) assert result_none == data # Test with explicit prototype - result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) assert result_proto == data -async def test_get_json_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_json_async works with prototype=None.""" - import json - - from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype +@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) +async def test_get_json_with_prototype_none( + tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype +) -> None: + """Test that get_json works with prototype=None.""" store = await LocalStore.open(root=tmp_path) data = {"foo": "bar", "number": 42} @@ -204,21 +204,15 @@ async def test_get_json_with_prototype_none(tmp_path: pathlib.Path) -> None: await store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode())) # Test with None (default) - result_none = await store.get_json_async(key) - assert result_none == data + result = await store.get_json(key, prototype=buffer_cls) + assert result == data - # Test with explicit prototype - result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) - assert result_proto == data - -def test_get_json_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_json works with prototype=None.""" - import json - - from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype - from zarr.core.sync import sync +@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) +def test_get_json_sync_with_prototype( + tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype +) -> None: + """Test that get_json_sync works with prototype=None.""" store = sync(LocalStore.open(root=tmp_path)) data = {"foo": "bar", "number": 42} @@ -226,9 +220,5 @@ def test_get_json_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: sync(store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode()))) # Test with None (default) - result_none = store.get_json(key) - assert result_none == data - - # Test with explicit prototype - result_proto = store.get_json(key, prototype=default_buffer_prototype()) - assert result_proto == data + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index b56d9933d4..c47c1adb12 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -85,11 +85,11 @@ async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: await self.set(store, key, self.buffer_cls.from_bytes(data)) # Test with None (default) - result_none = await store.get_bytes_async(key) + result_none = await store.get_bytes(key) assert result_none == data # Test with explicit prototype - result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) assert result_proto == data def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: @@ -102,15 +102,15 @@ def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(data))) # Test with None (default) - result_none = store.get_bytes(key) + result_none = store.get_bytes_sync(key) assert result_none == data # Test with explicit prototype - result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) assert result_proto == data async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_json_async works with prototype=None.""" + """Test that get_json works with prototype=None.""" import json from zarr.core.buffer.core import default_buffer_prototype @@ -120,11 +120,11 @@ async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) # Test with None (default) - result_none = await store.get_json_async(key) + result_none = await store.get_json(key) assert result_none == data # Test with explicit prototype - result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) + result_proto = await store.get_json(key, prototype=default_buffer_prototype()) assert result_proto == data def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: @@ -139,11 +139,11 @@ def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) # Test with None (default) - result_none = store.get_json(key) + result_none = store.get_json_sync(key) assert result_none == data # Test with explicit prototype - result_proto = store.get_json(key, prototype=default_buffer_prototype()) + result_proto = store.get_json_sync(key, prototype=default_buffer_prototype()) assert result_proto == data From 38ff5172cc126d15abfa992ceee9cba81f7e9e3e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 11:02:07 +0100 Subject: [PATCH 06/10] continue renaming / test refactoring --- src/zarr/abc/store.py | 10 ++--- src/zarr/storage/_local.py | 22 ++++----- src/zarr/storage/_memory.py | 22 ++++----- src/zarr/testing/store.py | 8 ++-- tests/test_store/test_memory.py | 79 +++++++++++++-------------------- 5 files changed, 61 insertions(+), 80 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index e685e4b3b0..0e98777ff5 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -248,7 +248,7 @@ async def get_bytes( -------- >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = await store.get_bytes_async("data", prototype=default_buffer_prototype()) + >>> data = await store.get_bytes("data", prototype=default_buffer_prototype()) >>> print(data) b'hello world' """ @@ -263,8 +263,8 @@ def get_bytes_sync( """ Retrieve raw bytes from the store synchronously. - This is a synchronous wrapper around ``get_bytes_async()``. It should only - be called from non-async code. For async contexts, use ``get_bytes_async()`` + This is a synchronous wrapper around ``get_bytes()``. It should only + be called from non-async code. For async contexts, use ``get_bytes()`` instead. Parameters @@ -294,8 +294,8 @@ def get_bytes_sync( See Also -------- - get_bytes_async : Asynchronous version of this method. - get_json : Synchronous method for retrieving and parsing JSON data. + get_bytes : Asynchronous version of this method. + get_json_sync : Synchronous method for retrieving and parsing JSON data. Examples -------- diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 08681f2630..9fb3f8b6ad 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -317,7 +317,7 @@ async def get_bytes( Retrieve raw bytes from the local store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` for full documentation. Parameters @@ -342,15 +342,15 @@ async def get_bytes( See Also -------- - Store.get_bytes_async : Base implementation with full documentation. - get_bytes : Synchronous version of this method. + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. Examples -------- >>> store = await LocalStore.open("data") >>> await store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for LocalStore - >>> data = await store.get_bytes_async("data") + >>> data = await store.get_bytes("data") >>> print(data) b'hello' """ @@ -394,12 +394,12 @@ def get_bytes_sync( Warnings -------- - Do not call this method from async functions. Use ``get_bytes_async()`` instead. + Do not call this method from async functions. Use ``get_bytes()`` instead. See Also -------- - Store.get_bytes : Base implementation with full documentation. - get_bytes_async : Asynchronous version of this method. + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. Examples -------- @@ -455,8 +455,8 @@ async def get_json( See Also -------- Store.get_json : Base implementation with full documentation. - get_json : Synchronous version of this method. - get_bytes_async : Method for retrieving raw bytes without parsing. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. Examples -------- @@ -517,9 +517,9 @@ def get_json_sync( See Also -------- - Store.get_json : Base implementation with full documentation. + Store.get_json_sync : Base implementation with full documentation. get_json : Asynchronous version of this method. - get_bytes : Method for retrieving raw bytes without parsing. + get_bytes_sync : Method for retrieving raw bytes without parsing. Examples -------- diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 5a2593eb25..1568cc6736 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -186,7 +186,7 @@ async def get_bytes( Retrieve raw bytes from the memory store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` for full documentation. Parameters @@ -211,15 +211,15 @@ async def get_bytes( See Also -------- - Store.get_bytes_async : Base implementation with full documentation. - get_bytes : Synchronous version of this method. + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. Examples -------- >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for MemoryStore - >>> data = await store.get_bytes_async("data") + >>> data = await store.get_bytes("data") >>> print(data) b'hello' """ @@ -263,12 +263,12 @@ def get_bytes_sync( Warnings -------- - Do not call this method from async functions. Use ``get_bytes_async()`` instead. + Do not call this method from async functions. Use ``get_bytes()`` instead. See Also -------- - Store.get_bytes : Base implementation with full documentation. - get_bytes_async : Asynchronous version of this method. + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. Examples -------- @@ -324,8 +324,8 @@ async def get_json( See Also -------- Store.get_json : Base implementation with full documentation. - get_json : Synchronous version of this method. - get_bytes_async : Method for retrieving raw bytes without parsing. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. Examples -------- @@ -386,9 +386,9 @@ def get_json_sync( See Also -------- - Store.get_json : Base implementation with full documentation. + Store.get_json_sync : Base implementation with full documentation. get_json : Asynchronous version of this method. - get_bytes : Method for retrieving raw bytes without parsing. + get_bytes_sync : Method for retrieving raw bytes without parsing. Examples -------- diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index f0cb6dd48f..a56061ae12 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -527,9 +527,9 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new - async def test_get_bytes_async(self, store: S) -> None: + async def test_get_bytes(self, store: S) -> None: """ - Test that the get_bytes_async method reads bytes. + Test that the get_bytes method reads bytes. """ data = b"hello world" key = "zarr.json" @@ -540,7 +540,7 @@ async def test_get_bytes_async(self, store: S) -> None: def test_get_bytes_sync(self, store: S) -> None: """ - Test that the get_bytes method reads bytes. + Test that the get_bytes_sync method reads bytes. """ data = b"hello world" key = "zarr.json" @@ -549,7 +549,7 @@ def test_get_bytes_sync(self, store: S) -> None: async def test_get_json(self, store: S) -> None: """ - Test that the get_bytes_async method reads json. + Test that the get_json method reads json. """ data = {"foo": "bar"} data_bytes = json.dumps(data).encode("utf-8") diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index c47c1adb12..96b7fe9845 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Any @@ -9,12 +10,14 @@ import zarr from zarr.core.buffer import Buffer, cpu, gpu +from zarr.core.sync import sync from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype from zarr.core.common import ZarrFormat @@ -76,75 +79,53 @@ async def test_deterministic_size( np.testing.assert_array_equal(a[:3], 1) np.testing.assert_array_equal(a[3:], 0) - async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_bytes_async works with prototype=None.""" - from zarr.core.buffer.core import default_buffer_prototype - + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" data = b"hello world" key = "test_key" await self.set(store, key, self.buffer_cls.from_bytes(data)) - # Test with None (default) - result_none = await store.get_bytes(key) - assert result_none == data - - # Test with explicit prototype - result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) - assert result_proto == data - - def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_bytes works with prototype=None.""" - from zarr.core.buffer.core import default_buffer_prototype - from zarr.core.sync import sync + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" data = b"hello world" key = "test_key" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) - # Test with None (default) - result_none = store.get_bytes_sync(key) - assert result_none == data + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data - # Test with explicit prototype - result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) - assert result_proto == data - - async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: """Test that get_json works with prototype=None.""" - import json - - from zarr.core.buffer.core import default_buffer_prototype - data = {"foo": "bar", "number": 42} key = "test.json" await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) - # Test with None (default) - result_none = await store.get_json(key) - assert result_none == data - - # Test with explicit prototype - result_proto = await store.get_json(key, prototype=default_buffer_prototype()) - assert result_proto == data - - def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_json works with prototype=None.""" - import json - - from zarr.core.buffer.core import default_buffer_prototype - from zarr.core.sync import sync + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" data = {"foo": "bar", "number": 42} key = "test.json" sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) - # Test with None (default) - result_none = store.get_json_sync(key) - assert result_none == data - - # Test with explicit prototype - result_proto = store.get_json_sync(key, prototype=default_buffer_prototype()) - assert result_proto == data + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data # TODO: fix this warning From bdc4ef864b3bcbe422ab981eab6ec94f7af3ac0a Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 11:47:14 +0100 Subject: [PATCH 07/10] refactor new test functions --- tests/test_store/test_local.py | 122 ++++++++++++++------------------- 1 file changed, 52 insertions(+), 70 deletions(-) diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index d6a97110ad..fa4bc7cfc0 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -3,6 +3,7 @@ import json import pathlib import re +from typing import TYPE_CHECKING import numpy as np import pytest @@ -10,13 +11,15 @@ import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu -from zarr.core.buffer.core import BufferPrototype, default_buffer_prototype from zarr.core.sync import sync from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests from zarr.testing.utils import assert_bytes_equal +if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype + class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]): store_cls = LocalStore @@ -111,6 +114,54 @@ async def test_move( ): await store2.move(destination) + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: @@ -153,72 +204,3 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files - - -async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes works with prototype=None.""" - from zarr.core.buffer import cpu - - store = await LocalStore.open(root=tmp_path) - data = b"hello world" - key = "test_key" - await store.set(key, cpu.Buffer.from_bytes(data)) - - # Test with None (default) - result_none = await store.get_bytes(key) - assert result_none == data - - # Test with explicit prototype - result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) - assert result_proto == data - - -def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes_sync works with prototype=None.""" - from zarr.core.buffer import cpu - from zarr.core.sync import sync - - store = sync(LocalStore.open(root=tmp_path)) - data = b"hello world" - key = "test_key" - sync(store.set(key, cpu.Buffer.from_bytes(data))) - - # Test with None (default) - result_none = store.get_bytes_sync(key) - assert result_none == data - - # Test with explicit prototype - result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) - assert result_proto == data - - -@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) -async def test_get_json_with_prototype_none( - tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype -) -> None: - """Test that get_json works with prototype=None.""" - - store = await LocalStore.open(root=tmp_path) - data = {"foo": "bar", "number": 42} - key = "test.json" - await store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode())) - - # Test with None (default) - result = await store.get_json(key, prototype=buffer_cls) - assert result == data - - -@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) -def test_get_json_sync_with_prototype( - tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype -) -> None: - """Test that get_json_sync works with prototype=None.""" - - store = sync(LocalStore.open(root=tmp_path)) - data = {"foo": "bar", "number": 42} - key = "test.json" - sync(store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode()))) - - # Test with None (default) - result = store.get_json_sync(key, prototype=buffer_cls) - assert result == data From b110768d3557c83a34125cde6382ed4e358301ca Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 16:03:43 +0100 Subject: [PATCH 08/10] add BufferLike as buffer parameter for store methods that allocate memory --- src/zarr/abc/store.py | 90 +++++++++++++++++++++------- src/zarr/experimental/cache_store.py | 18 +++--- src/zarr/storage/_common.py | 21 +++---- src/zarr/storage/_fsspec.py | 42 +++++++++---- src/zarr/storage/_local.py | 45 ++++++++------ src/zarr/storage/_logging.py | 8 +-- src/zarr/storage/_memory.py | 39 +++++++----- src/zarr/storage/_obstore.py | 39 +++++++++--- src/zarr/storage/_wrapper.py | 13 ++-- src/zarr/storage/_zip.py | 27 +++++++-- src/zarr/testing/store.py | 86 ++++++++++++++++++++++++-- tests/test_store/test_wrapper.py | 13 ++-- 12 files changed, 326 insertions(+), 115 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 0e98777ff5..a4eefecf3c 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -7,6 +7,7 @@ from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.sync import sync if TYPE_CHECKING: @@ -14,9 +15,9 @@ from types import TracebackType from typing import Any, Self, TypeAlias - from zarr.core.buffer import Buffer, BufferPrototype +__all__ = ["BufferLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"] -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +BufferLike = type[Buffer] | BufferPrototype @dataclass @@ -183,11 +184,18 @@ def __eq__(self, value: object) -> bool: """Equality comparison.""" ... + @abstractmethod + def _get_default_buffer_class(self) -> type[Buffer]: + """ + Get the default buffer class for this store. + """ + ... + @abstractmethod async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -195,8 +203,12 @@ async def get( Parameters ---------- key : str - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved. - RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned. @@ -210,7 +222,11 @@ async def get( ... async def get_bytes( - self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> bytes: """ Retrieve raw bytes from the store asynchronously. @@ -222,8 +238,12 @@ async def get_bytes( ---------- key : str The key identifying the data to retrieve. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -258,7 +278,11 @@ async def get_bytes( return buffer.to_bytes() def get_bytes_sync( - self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> bytes: """ Retrieve raw bytes from the store synchronously. @@ -271,8 +295,12 @@ def get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -309,7 +337,11 @@ def get_bytes_sync( return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range)) async def get_json( - self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> Any: """ Retrieve and parse JSON data from the store asynchronously. @@ -321,8 +353,12 @@ async def get_json( ---------- key : str The key identifying the JSON data to retrieve. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -359,7 +395,11 @@ async def get_json( return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range)) def get_json_sync( - self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> Any: """ Retrieve and parse JSON data from the store synchronously. @@ -372,8 +412,12 @@ def get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -417,15 +461,19 @@ def get_json_sync( @abstractmethod async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. Parameters ---------- - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferLike | None + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. key_ranges : Iterable[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges @@ -597,7 +645,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage. In general this method does not guarantee diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py index 3456c94320..e696e0eb0f 100644 --- a/src/zarr/experimental/cache_store.py +++ b/src/zarr/experimental/cache_store.py @@ -6,13 +6,13 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.storage._wrapper import WrapperStore logger = logging.getLogger(__name__) if TYPE_CHECKING: - from zarr.core.buffer.core import Buffer, BufferPrototype + from zarr.core.buffer.core import Buffer class CacheStore(WrapperStore[Store]): @@ -218,7 +218,7 @@ def _remove_from_tracking(self, key: str) -> None: self._key_sizes.pop(key, None) async def _get_try_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Try to get data from cache first, falling back to source store.""" maybe_cached_result = await self._cache.get(key, prototype, byte_range) @@ -246,7 +246,7 @@ async def _get_try_cache( return maybe_fresh_result async def _get_no_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Get data directly from source store and update cache.""" self._misses += 1 @@ -265,7 +265,7 @@ async def _get_no_cache( async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -275,8 +275,12 @@ async def get( ---------- key : str The key to retrieve - prototype : BufferPrototype - Buffer prototype for creating the result buffer + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional Byte range to retrieve diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..e381c65839 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -3,10 +3,10 @@ import importlib.util import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.core.buffer import Buffer from zarr.core.common import ( ANY_ACCESS_MODE, ZARR_JSON, @@ -26,9 +26,6 @@ else: FSMap = None -if TYPE_CHECKING: - from zarr.core.buffer import BufferPrototype - def _dereference_path(root: str, path: str) -> str: if not isinstance(root, str): @@ -145,7 +142,7 @@ async def open(cls, store: Store, path: str, mode: AccessModeLiteral | None = No async def get( self, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -153,8 +150,12 @@ async def get( Parameters ---------- - prototype : BufferPrototype, optional - The buffer prototype to use when reading the bytes. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + store's ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional The range of bytes to read. @@ -164,7 +165,7 @@ async def get( The read bytes, or None if the key does not exist. """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.store._get_default_buffer_class() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) async def set(self, value: Buffer) -> None: diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..c8a80a9554 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -8,13 +8,15 @@ from packaging.version import parse as parse_version from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer +from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -25,8 +27,6 @@ from fsspec.asyn import AsyncFileSystem from fsspec.mapping import FSMap - from zarr.core.buffer import BufferPrototype - ALLOWED_EXCEPTIONS: tuple[type[Exception], ...] = ( FileNotFoundError, @@ -273,22 +273,34 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: await self._open() + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + path = _dereference_path(self.path, key) try: if byte_range is None: - value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) + value = buffer_cls.from_bytes(await self.fs._cat_file(path)) elif isinstance(byte_range, RangeByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file( path, start=byte_range.start, @@ -296,11 +308,11 @@ async def get( ) ) elif isinstance(byte_range, OffsetByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=byte_range.offset, end=None) ) elif isinstance(byte_range, SuffixByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=-byte_range.suffix, end=None) ) else: @@ -310,7 +322,7 @@ async def get( except OSError as e: if "not satisfiable" in str(e): # this is an s3-specific condition we probably don't want to leak - return prototype.buffer.from_bytes(b"") + return buffer_cls.from_bytes(b"") raise else: return value @@ -367,10 +379,18 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + if key_ranges: # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. key_ranges = list(key_ranges) @@ -403,7 +423,7 @@ async def get_partial_values( if isinstance(r, Exception) and not isinstance(r, self.allowed_exceptions): raise r - return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res] + return [None if isinstance(r, Exception) else buffer_cls.from_bytes(r) for r in res] async def list(self) -> AsyncIterator[str]: # docstring inherited diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 9fb3f8b6ad..351b69b275 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -11,37 +11,42 @@ from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import AccessModeLiteral, concurrent_map if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator - from zarr.core.buffer import BufferPrototype +def _get(path: Path, prototype: BufferLike, byte_range: ByteRequest | None) -> Buffer: + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype -def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: if byte_range is None: - return prototype.buffer.from_bytes(path.read_bytes()) + return buffer_cls.from_bytes(path.read_bytes()) with path.open("rb") as f: size = f.seek(0, io.SEEK_END) if isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) elif isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) elif isinstance(byte_range, SuffixByteRequest): f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) if sys.platform == "win32": @@ -187,15 +192,19 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() if not self._is_open: await self._open() assert isinstance(key, str) @@ -208,10 +217,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() args = [] for key, byte_range in key_ranges: assert isinstance(key, str) @@ -310,7 +321,7 @@ async def get_bytes( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -355,14 +366,14 @@ async def get_bytes( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) def get_bytes_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -411,14 +422,14 @@ def get_bytes_sync( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) async def get_json( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -470,14 +481,14 @@ async def get_json( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_json(key, prototype=prototype, byte_range=byte_range) def get_json_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -533,7 +544,7 @@ def get_json_sync( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) async def move(self, dest_root: Path | str) -> None: diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index dd20d49ae5..7d82dac948 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -8,14 +8,14 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Self, TypeVar -from zarr.abc.store import Store +from zarr.abc.store import BufferLike, Store from zarr.storage._wrapper import WrapperStore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable from zarr.abc.store import ByteRequest - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer counter: defaultdict[str, int] @@ -165,7 +165,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -174,7 +174,7 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 1568cc6736..15ee3855df 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -3,8 +3,8 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, gpu +from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.core.buffer import Buffer, BufferPrototype, gpu from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index @@ -12,8 +12,6 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, MutableMapping - from zarr.core.buffer import BufferPrototype - logger = getLogger(__name__) @@ -60,6 +58,10 @@ def with_read_only(self, read_only: bool = False) -> MemoryStore: read_only=read_only, ) + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + async def clear(self) -> None: # docstring inherited self._store_dict.clear() @@ -80,25 +82,30 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype if not self._is_open: await self._open() assert isinstance(key, str) try: value = self._store_dict[key] start, stop = _normalize_byte_range_index(value, byte_range) - return prototype.buffer.from_buffer(value[start:stop]) + return buffer_cls.from_buffer(value[start:stop]) except KeyError: return None async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited @@ -179,7 +186,7 @@ async def get_bytes( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -224,14 +231,14 @@ async def get_bytes( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) def get_bytes_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -280,14 +287,14 @@ def get_bytes_sync( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) async def get_json( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -339,14 +346,14 @@ async def get_json( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_json(key, prototype=prototype, byte_range=byte_range) def get_json_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -402,7 +409,7 @@ def get_json_sync( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..baa469c430 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -7,12 +7,15 @@ from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) +from zarr.core.buffer import BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.core.config import config @@ -23,7 +26,7 @@ from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange from obstore.store import ObjectStore as _UpstreamObjectStore - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer __all__ = ["ObjectStore"] @@ -94,26 +97,40 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + from zarr.core.buffer.core import default_buffer_prototype + + return default_buffer_prototype().buffer + async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: # docstring inherited import obstore as obs + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + try: if byte_range is None: resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, RangeByteRequest): bytes = await obs.get_range_async( self.store, key, start=byte_range.start, end=byte_range.end ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + return buffer_cls.from_bytes(bytes) # type: ignore[arg-type] elif isinstance(byte_range, OffsetByteRequest): resp = await obs.get_async( self.store, key, options={"range": {"offset": byte_range.offset}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, SuffixByteRequest): # some object stores (Azure) don't support suffix requests. In this # case, our workaround is to first get the length of the object and then @@ -122,7 +139,7 @@ async def get( resp = await obs.get_async( self.store, key, options={"range": {"suffix": byte_range.suffix}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] except obs.exceptions.NotSupportedError: head_resp = await obs.head_async(self.store, key) file_size = head_resp["size"] @@ -133,7 +150,7 @@ async def get( start=file_size - suffix_len, length=suffix_len, ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + return buffer_cls.from_bytes(buffer) # type: ignore[arg-type] else: raise ValueError(f"Unexpected byte_range, got {byte_range}") except _ALLOWED_EXCEPTIONS: @@ -141,10 +158,16 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike - _get_partial_values expects BufferPrototype + if not isinstance(prototype, BufferPrototype): + # Convert raw buffer class to BufferPrototype + prototype = default_buffer_prototype() return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) async def exists(self, key: str) -> bool: diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index 64a5b2d83c..b105dcf0d2 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -9,9 +9,8 @@ from zarr.abc.buffer import Buffer from zarr.abc.store import ByteRequest - from zarr.core.buffer import BufferPrototype -from zarr.abc.store import Store +from zarr.abc.store import BufferLike, Store T_Store = TypeVar("T_Store", bound=Store) @@ -84,14 +83,18 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return self._store._get_default_buffer_class() + async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) @@ -139,7 +142,7 @@ def close(self) -> None: self._store.close() async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 72bf9e335a..64fe902e52 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -16,6 +17,7 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -143,22 +145,31 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.path == other.path + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + def _get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike, byte_range: ByteRequest | None = None, ) -> Buffer | None: if not self._is_open: self._sync_open() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype # docstring inherited try: with self._zf.open(key) as f: # will raise KeyError if byte_range is None: - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) elif isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) size = f.seek(0, os.SEEK_END) if isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) @@ -166,17 +177,19 @@ def _get( f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) except KeyError: return None async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() assert isinstance(key, str) with self._lock: @@ -184,10 +197,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() out = [] with self._lock: for key, byte_range in key_ranges: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index a56061ae12..7613dee04a 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -12,11 +12,11 @@ from typing import Any from zarr.abc.store import ByteRequest - from zarr.core.buffer.core import BufferPrototype import pytest from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -244,6 +244,32 @@ async def test_get_raises(self, store: S) -> None: with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"): await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2)) # type: ignore[arg-type] + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) + async def test_get_with_buffer_like(self, store: S, prototype: BufferLike | None) -> None: + """ + Test that store.get() works with all BufferLike variants: + - None (uses store's default) + - BufferPrototype instance + - Raw Buffer class + """ + data = b"\x01\x02\x03\x04" + key = "test_buffer_like" + data_buf = self.buffer_cls.from_bytes(data) + await self.set(store, key, data_buf) + + # Get with the parametrized prototype + observed = await store.get(key, prototype=prototype) + assert observed is not None + assert_bytes_equal(observed, data_buf) + async def test_get_many(self, store: S) -> None: """ Ensure that multiple keys can be retrieved at once with the _get_many method. @@ -376,6 +402,54 @@ async def test_get_partial_values( obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) ) + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) + async def test_get_partial_values_with_buffer_like( + self, store: S, prototype: BufferLike | None + ) -> None: + """ + Test that store.get_partial_values() works with all BufferLike variants: + - None (uses store's default) + - BufferPrototype instance + - Raw Buffer class + """ + key_ranges: list[tuple[str, ByteRequest | None]] = [ + ("c/0", RangeByteRequest(0, 2)), + ("c/1", None), + ("c/2", SuffixByteRequest(2)), + ] + + # put all of the data + for key, _ in key_ranges: + await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) + + # read back with the parametrized prototype + observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges) + + observed: list[Buffer] = [] + expected: list[Buffer] = [] + + for obs in observed_maybe: + assert obs is not None + observed.append(obs) + + for idx in range(len(observed)): + key, byte_range = key_ranges[idx] + result = await store.get(key, prototype=prototype, byte_range=byte_range) + assert result is not None + expected.append(result) + + assert all( + obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) + ) + async def test_exists(self, store: S) -> None: assert not await store.exists("foo") await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) @@ -604,7 +678,7 @@ async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: """ Add latency to the ``get`` method. @@ -615,8 +689,12 @@ async def get( ---------- key : str The key to get - prototype : BufferPrototype - The BufferPrototype to use. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional An optional byte range. diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py index b34a63d5d0..c5f2240297 100644 --- a/tests/test_store/test_wrapper.py +++ b/tests/test_store/test_wrapper.py @@ -4,7 +4,7 @@ import pytest -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.core.buffer.cpu import buffer_prototype @@ -14,8 +14,6 @@ if TYPE_CHECKING: from pathlib import Path - from zarr.core.buffer.core import BufferPrototype - class StoreKwargs(TypedDict): store: LocalStore @@ -111,10 +109,13 @@ async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> # define a class that prints when it sets class NoisyGetter(WrapperStore[Any]): async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> None: + self, + key: str, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: print(f"getting {key}") - await super().get(key, prototype=prototype, byte_range=byte_range) + return await super().get(key, prototype=prototype, byte_range=byte_range) key = "foo" value = CPUBuffer.from_bytes(b"bar") From 6b9de9db2d6594f26a65374b0b25cdcb12b15d7b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 16:57:38 +0100 Subject: [PATCH 09/10] implement default on store abc --- src/zarr/abc/store.py | 6 +++--- src/zarr/storage/_fsspec.py | 5 ----- src/zarr/storage/_local.py | 5 ----- src/zarr/storage/_memory.py | 5 ----- src/zarr/storage/_obstore.py | 6 ------ src/zarr/storage/_wrapper.py | 4 ---- src/zarr/storage/_zip.py | 5 ----- 7 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index a4eefecf3c..25aaba4aa9 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.core.sync import sync if TYPE_CHECKING: @@ -184,12 +185,11 @@ def __eq__(self, value: object) -> bool: """Equality comparison.""" ... - @abstractmethod def _get_default_buffer_class(self) -> type[Buffer]: """ - Get the default buffer class for this store. + Get the default buffer class. """ - ... + return default_buffer_prototype().buffer @abstractmethod async def get( diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index c8a80a9554..b16712c786 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -16,7 +16,6 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -273,10 +272,6 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - async def get( self, key: str, diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 351b69b275..f991765723 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,7 +19,6 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import AccessModeLiteral, concurrent_map if TYPE_CHECKING: @@ -192,10 +191,6 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - async def get( self, key: str, diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 15ee3855df..c28dc910b4 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -5,7 +5,6 @@ from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.core.buffer import Buffer, BufferPrototype, gpu -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index @@ -58,10 +57,6 @@ def with_read_only(self, read_only: bool = False) -> MemoryStore: read_only=read_only, ) - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - async def clear(self) -> None: # docstring inherited self._store_dict.clear() diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index baa469c430..aff000afe9 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -97,12 +97,6 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - from zarr.core.buffer.core import default_buffer_prototype - - return default_buffer_prototype().buffer - async def get( self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index b105dcf0d2..ca3609009e 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -83,10 +83,6 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return self._store._get_default_buffer_class() - async def get( self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 64fe902e52..0348eeedd8 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -17,7 +17,6 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -145,10 +144,6 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.path == other.path - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - def _get( self, key: str, From 281538a2966fa15c229a6de6352c2fefd01b53d7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 18:00:08 +0100 Subject: [PATCH 10/10] consolidate prototype testing --- src/zarr/testing/store.py | 108 +++++++++----------------------------- 1 file changed, 25 insertions(+), 83 deletions(-) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 7613dee04a..55e3687f20 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -23,7 +23,7 @@ Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -202,6 +202,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: ): await reader.delete("foo") + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize( ("data", "byte_range"), @@ -213,13 +222,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: (b"", None), ], ) - async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None: + async def test_get( + self, store: S, key: str, data: bytes, byte_range: ByteRequest, prototype: BufferLike | None + ) -> None: """ Ensure that data can be read from the store using the store.get method. """ data_buf = self.buffer_cls.from_bytes(data) await self.set(store, key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range) + observed = await store.get(key, prototype=prototype, byte_range=byte_range) start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range) expected = data_buf[start:stop] assert_bytes_equal(observed, expected) @@ -244,32 +255,6 @@ async def test_get_raises(self, store: S) -> None: with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"): await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2)) # type: ignore[arg-type] - @pytest.mark.parametrize( - "prototype", - [ - None, # Should use store's default buffer class - default_buffer_prototype(), # BufferPrototype instance - default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) - ], - ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], - ) - async def test_get_with_buffer_like(self, store: S, prototype: BufferLike | None) -> None: - """ - Test that store.get() works with all BufferLike variants: - - None (uses store's default) - - BufferPrototype instance - - Raw Buffer class - """ - data = b"\x01\x02\x03\x04" - key = "test_buffer_like" - data_buf = self.buffer_cls.from_bytes(data) - await self.set(store, key, data_buf) - - # Get with the parametrized prototype - observed = await store.get(key, prototype=prototype) - assert observed is not None - assert_bytes_equal(observed, data_buf) - async def test_get_many(self, store: S) -> None: """ Ensure that multiple keys can be retrieved at once with the _get_many method. @@ -358,6 +343,15 @@ async def test_set_many(self, store: S) -> None: for k, v in store_dict.items(): assert (await self.get(store, k)).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize( "key_ranges", [ @@ -372,65 +366,13 @@ async def test_set_many(self, store: S) -> None: ], ) async def test_get_partial_values( - self, store: S, key_ranges: list[tuple[str, ByteRequest]] + self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferLike | None ) -> None: # put all of the data for key, _ in key_ranges: await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) # read back just part of it - observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=key_ranges - ) - - observed: list[Buffer] = [] - expected: list[Buffer] = [] - - for obs in observed_maybe: - assert obs is not None - observed.append(obs) - - for idx in range(len(observed)): - key, byte_range = key_ranges[idx] - result = await store.get( - key, prototype=default_buffer_prototype(), byte_range=byte_range - ) - assert result is not None - expected.append(result) - - assert all( - obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) - ) - - @pytest.mark.parametrize( - "prototype", - [ - None, # Should use store's default buffer class - default_buffer_prototype(), # BufferPrototype instance - default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) - ], - ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], - ) - async def test_get_partial_values_with_buffer_like( - self, store: S, prototype: BufferLike | None - ) -> None: - """ - Test that store.get_partial_values() works with all BufferLike variants: - - None (uses store's default) - - BufferPrototype instance - - Raw Buffer class - """ - key_ranges: list[tuple[str, ByteRequest | None]] = [ - ("c/0", RangeByteRequest(0, 2)), - ("c/1", None), - ("c/2", SuffixByteRequest(2)), - ] - - # put all of the data - for key, _ in key_ranges: - await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) - - # read back with the parametrized prototype observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges) observed: list[Buffer] = [] @@ -442,7 +384,7 @@ async def test_get_partial_values_with_buffer_like( for idx in range(len(observed)): key, byte_range = key_ranges[idx] - result = await store.get(key, prototype=prototype, byte_range=byte_range) + result = await store.get(key, prototype=cpu.Buffer, byte_range=byte_range) assert result is not None expected.append(result)