Skip to content
1 change: 1 addition & 0 deletions changes/3588.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix missing `_lock` attribute error in ZipStore when calling some methods of before opening it
24 changes: 18 additions & 6 deletions src/zarr/storage/_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def _sync_open(self) -> None:

self._is_open = True

def _sync_ensure_open(self) -> None:
if not self._is_open:
self._sync_open()

async def _open(self) -> None:
self._sync_open()

Expand All @@ -120,12 +124,16 @@ def __setstate__(self, state: dict[str, Any]) -> None:

def close(self) -> None:
# docstring inherited
self._sync_ensure_open()

super().close()
with self._lock:
self._zf.close()

async def clear(self) -> None:
# docstring inherited
self._sync_ensure_open()

with self._lock:
self._check_writable()
self._zf.close()
Expand All @@ -149,8 +157,7 @@ def _get(
prototype: BufferPrototype,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
if not self._is_open:
self._sync_open()
self._sync_ensure_open()
# docstring inherited
try:
with self._zf.open(key) as f: # will raise KeyError
Expand Down Expand Up @@ -188,15 +195,15 @@ async def get_partial_values(
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
# docstring inherited
self._sync_ensure_open()
out = []
with self._lock:
for key, byte_range in key_ranges:
out.append(self._get(key, prototype=prototype, byte_range=byte_range))
return out

def _set(self, key: str, value: Buffer) -> None:
if not self._is_open:
self._sync_open()
self._sync_ensure_open()
# generally, this should be called inside a lock
keyinfo = zipfile.ZipInfo(filename=key, date_time=time.localtime(time.time())[:6])
keyinfo.compress_type = self.compression
Expand All @@ -210,8 +217,7 @@ def _set(self, key: str, value: Buffer) -> None:
async def set(self, key: str, value: Buffer) -> None:
# docstring inherited
self._check_writable()
if not self._is_open:
self._sync_open()
self._sync_ensure_open()
assert isinstance(key, str)
if not isinstance(value, Buffer):
raise TypeError(
Expand All @@ -222,6 +228,8 @@ async def set(self, key: str, value: Buffer) -> None:

async def set_if_not_exists(self, key: str, value: Buffer) -> None:
self._check_writable()
self._sync_ensure_open()

with self._lock:
members = self._zf.namelist()
if key not in members:
Expand All @@ -245,6 +253,8 @@ async def delete(self, key: str) -> None:

async def exists(self, key: str) -> bool:
# docstring inherited
self._sync_ensure_open()

with self._lock:
try:
self._zf.getinfo(key)
Expand All @@ -255,6 +265,8 @@ async def exists(self, key: str) -> bool:

async def list(self) -> AsyncIterator[str]:
# docstring inherited
self._sync_ensure_open()

with self._lock:
for key in self._zf.namelist():
yield key
Expand Down
14 changes: 14 additions & 0 deletions tests/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,17 @@ async def test_move(self, tmp_path: Path) -> None:
assert destination.exists()
assert not origin.exists()
assert np.array_equal(array[...], np.arange(10))

async def test_lock_present(self, store: ZipStore) -> None:
buf = cpu.Buffer.from_bytes(b"bar")
await store.set("foo", buf)
await store.set_if_not_exists("foo", buf)
await store.exists("foo")
await store.get("foo", default_buffer_prototype())

async for _ in store.list():
pass

await store.clear()

store.close()