From 31d1f642072a35f03423c656a68e77aa4144f6a1 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Fri, 23 May 2025 18:45:26 +0000 Subject: [PATCH 1/4] feat: Added read_time as a parameter to various calls (async classes) --- .../cloud/firestore_v1/async_aggregation.py | 20 ++ google/cloud/firestore_v1/async_client.py | 18 +- google/cloud/firestore_v1/async_collection.py | 26 +- google/cloud/firestore_v1/async_document.py | 18 +- google/cloud/firestore_v1/async_query.py | 31 ++- .../cloud/firestore_v1/async_transaction.py | 17 ++ tests/system/test_system.py | 9 +- tests/system/test_system_async.py | 243 ++++++++++++++++++ tests/unit/v1/test_async_aggregation.py | 71 ++++- tests/unit/v1/test_async_client.py | 55 +++- tests/unit/v1/test_async_collection.py | 59 ++++- tests/unit/v1/test_async_document.py | 56 +++- tests/unit/v1/test_async_query.py | 99 ++++++- tests/unit/v1/test_async_transaction.py | 40 ++- 14 files changed, 688 insertions(+), 74 deletions(-) diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index 3f3a1b9f43..63a29f3447 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -20,6 +20,8 @@ """ from __future__ import annotations +import datetime + from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Union from google.api_core import gapic_v1 @@ -55,6 +57,7 @@ async def get( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[List[AggregationResult]]: """Runs the aggregation query. @@ -75,6 +78,10 @@ async def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. 
This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: QueryResultsList[List[AggregationResult]]: The aggregation query results. @@ -87,6 +94,7 @@ async def get( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) try: result = [aggregation async for aggregation in stream_result] @@ -106,6 +114,7 @@ async def _make_stream( retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[List[AggregationResult] | query_profile_pb.ExplainMetrics, Any]: """Internal method for stream(). Runs the aggregation query. @@ -130,6 +139,10 @@ async def _make_stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Yields: List[AggregationResult] | query_profile_pb.ExplainMetrics: @@ -143,6 +156,7 @@ async def _make_stream( retry, timeout, explain_options, + read_time, ) response_iterator = await self._client._firestore_api.run_aggregation_query( @@ -167,6 +181,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncStreamGenerator[List[AggregationResult]]: """Runs the aggregation query. @@ -190,6 +205,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Returns: `AsyncStreamGenerator[List[AggregationResult]]`: @@ -201,5 +220,6 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) return AsyncStreamGenerator(inner_generator, explain_options) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 275bcb9b61..f675150ccd 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -25,6 +25,8 @@ """ from __future__ import annotations +import datetime + from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union from google.api_core import gapic_v1 @@ -227,6 +229,8 @@ async def get_all( transaction: AsyncTransaction | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. @@ -261,13 +265,17 @@ async def get_all( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. 
""" request, reference_map, kwargs = self._prep_get_all( - references, field_paths, transaction, retry, timeout + references, field_paths, transaction, retry, timeout, read_time ) response_iterator = await self._firestore_api.batch_get_documents( @@ -283,6 +291,8 @@ async def collections( self, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[AsyncCollectionReference, Any]: """List top-level collections of the client's database. @@ -291,12 +301,16 @@ async def collections( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: iterator of subcollections of the current document. 
""" - request, kwargs = self._prep_collections(retry, timeout) + request, kwargs = self._prep_collections(retry, timeout, read_time) iterator = await self._firestore_api.list_collection_ids( request=request, metadata=self._rpc_metadata, diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 8c832b8f4c..08a9ae516f 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -15,6 +15,8 @@ """Classes for representing collections for the Google Cloud Firestore API.""" from __future__ import annotations +import datetime + from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple from google.api_core import gapic_v1 @@ -162,6 +164,8 @@ async def list_documents( page_size: int | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[DocumentReference, None]: """List all subdocuments of the current collection. @@ -173,6 +177,10 @@ async def list_documents( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: @@ -180,7 +188,9 @@ async def list_documents( collection does not exist at the time of `snapshot`, the iterator will be empty """ - request, kwargs = self._prep_list_documents(page_size, retry, timeout) + request, kwargs = self._prep_list_documents( + page_size, retry, timeout, read_time + ) iterator = await self._client._firestore_api.list_documents( request=request, @@ -197,6 +207,7 @@ async def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[DocumentSnapshot]: """Read the documents in this collection. @@ -216,6 +227,10 @@ async def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not allowed). @@ -227,6 +242,8 @@ async def get( query, kwargs = self._prep_get_or_stream(retry, timeout) if explain_options is not None: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time return await query.get(transaction=transaction, **kwargs) @@ -237,6 +254,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncStreamGenerator[DocumentSnapshot]: """Read the documents in this collection. 
@@ -268,6 +286,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: `AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query @@ -276,5 +298,7 @@ def stream( query, kwargs = self._prep_get_or_stream(retry, timeout) if explain_options: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time return query.stream(transaction=transaction, **kwargs) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 78c71b33fc..c3ebfbe0cc 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -329,6 +329,8 @@ async def get( transaction=None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -351,6 +353,10 @@ async def get( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Returns: :class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot`: @@ -362,7 +368,9 @@ async def get( """ from google.cloud.firestore_v1.base_client import _parse_batch_get - request, kwargs = self._prep_batch_get(field_paths, transaction, retry, timeout) + request, kwargs = self._prep_batch_get( + field_paths, transaction, retry, timeout, read_time + ) response_iter = await self._client._firestore_api.batch_get_documents( request=request, @@ -397,6 +405,8 @@ async def collections( page_size: int | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator: """List subcollections of the current document. @@ -408,6 +418,10 @@ async def collections( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Returns: Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: @@ -415,7 +429,7 @@ async def collections( document does not exist at the time of `snapshot`, the iterator will be empty """ - request, kwargs = self._prep_collections(page_size, retry, timeout) + request, kwargs = self._prep_collections(page_size, retry, timeout, read_time) iterator = await self._client._firestore_api.list_collection_ids( request=request, diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index d4fd45fa46..f8ebcbb920 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,6 +20,8 @@ """ from __future__ import annotations +import datetime + from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type from google.api_core import gapic_v1 @@ -182,6 +184,7 @@ async def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[DocumentSnapshot]: """Read the documents in the collection that match this query. @@ -201,6 +204,10 @@ async def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. 
read-after-write is not @@ -230,6 +237,7 @@ async def get( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) try: result_list = [d async for d in result] @@ -336,6 +344,7 @@ async def _make_stream( retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[DocumentSnapshot | query_profile_pb.ExplainMetrics, Any]: """Internal method for stream(). Read the documents in the collection that match this query. @@ -368,6 +377,10 @@ async def _make_stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. Yields: [:class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot` \ @@ -381,6 +394,7 @@ async def _make_stream( retry, timeout, explain_options, + read_time, ) response_iterator = await self._client._firestore_api.run_query( @@ -412,6 +426,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncStreamGenerator[DocumentSnapshot]: """Read the documents in the collection that match this query. @@ -443,6 +458,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. 
+ read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. Returns: `AsyncStreamGenerator[DocumentSnapshot]`: @@ -453,6 +472,7 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time ) return AsyncStreamGenerator(inner_generator, explain_options) @@ -514,6 +534,8 @@ async def get_partitions( partition_count, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[QueryPartition, None]: """Partition a query for parallelization. @@ -529,8 +551,15 @@ async def get_partitions( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. 
""" - request, kwargs = self._prep_get_partitions(partition_count, retry, timeout) + request, kwargs = self._prep_get_partitions( + partition_count, retry, timeout, read_time + ) + pager = await self._client._firestore_api.partition_query( request=request, metadata=self._client._rpc_metadata, diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 038710929b..2e89f88d58 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -15,6 +15,8 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" from __future__ import annotations +import datetime + from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional from google.api_core import exceptions, gapic_v1 @@ -154,6 +156,8 @@ async def get_all( references: list, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieves multiple documents from Firestore. @@ -164,12 +168,18 @@ async def get_all( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. 
""" kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time return await self._client.get_all(references, transaction=self, **kwargs) async def get( @@ -179,6 +189,7 @@ async def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[DocumentSnapshot, Any] | AsyncStreamGenerator[DocumentSnapshot]: """ Retrieve a document or a query result from the database. @@ -195,6 +206,10 @@ async def get( Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. Can only be used when running a query, not a document reference. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: DocumentSnapshot: The next document snapshot that fulfills the query, @@ -206,6 +221,8 @@ async def get( reference. 
""" kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time if isinstance(ref_or_query, AsyncDocumentReference): if explain_options is not None: raise ValueError( diff --git a/tests/system/test_system.py b/tests/system/test_system.py index d10196ad16..8dc43a4a7d 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1142,10 +1142,7 @@ def test_collection_add(client, cleanup, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -@pytest.mark.parametrize("use_python_datetime", [True, False]) -def test_list_collections_with_read_time( - client, cleanup, database, use_python_datetime -): +def test_list_collections_with_read_time(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. collection_id = "coll-add" + UNIQUE_RESOURCE_ID @@ -1155,15 +1152,11 @@ def test_list_collections_with_read_time( data1 = {"foo": "bar"} update_time1, document_ref1 = collection.add(data1) - if use_python_datetime: - update_time1 = datetime.datetime.now(tz=datetime.timezone.utc) cleanup(document_ref1.delete) assert set(collection.list_documents()) == {document_ref1} data2 = {"bar": "baz"} update_time2, document_ref2 = collection.add(data2) - if use_python_datetime: - update_time2 = datetime.datetime.now(tz=datetime.timezone.utc) cleanup(document_ref2.delete) assert set(collection.list_documents()) == {document_ref1, document_ref2} assert set(collection.list_documents(read_time=update_time1)) == {document_ref1} diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 200be7d8ab..6ab22d6d8f 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -234,6 +234,39 @@ async def test_create_document(client, cleanup, database): assert stored_data == expected_data +@pytest.mark.parametrize("database", 
[None, FIRESTORE_OTHER_DB], indirect=True) +async def test_collections_w_read_time(client, cleanup, database): + first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + first_document_id = "doc" + UNIQUE_RESOURCE_ID + first_document = client.document(first_collection_id, first_document_id) + # Add to clean-up before API request (in case ``create()`` fails). + cleanup(first_document.delete) + + data = {"status": "new"} + write_result = await first_document.create(data) + read_time = write_result.update_time + num_collections = len([x async for x in client.collections(retry=RETRIES)]) + + second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" + second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" + second_document = client.document(second_collection_id, second_document_id) + cleanup(second_document.delete) + await second_document.create(data) + + # Test that listing current collections does have the second id. + curr_collections = [x async for x in client.collections(retry=RETRIES)] + assert len(curr_collections) > num_collections + ids = [collection.id for collection in curr_collections] + assert second_collection_id in ids + assert first_collection_id in ids + + # We're just testing that we added one collection at read_time, not two. 
+ collections = [x async for x in client.collections(retry=RETRIES, read_time=read_time)] + ids = [collection.id for collection in collections] + assert second_collection_id not in ids + assert first_collection_id in ids + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID @@ -260,6 +293,42 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_document_collections_w_read_time(client, cleanup, database): + collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID + document_id = "doc" + UNIQUE_RESOURCE_ID + document = client.document(collection_id, document_id) + # Add to clean-up before API request (in case ``create()`` fails). + cleanup(document.delete) + + data = {"now": firestore.SERVER_TIMESTAMP} + await document.create(data) + + original_child_ids = ["child1", "child2"] + read_time = None + + for child_id in original_child_ids: + subcollection = document.collection(child_id) + update_time, subdoc = await subcollection.add({"foo": "bar"}) + read_time = ( + update_time if read_time is None or update_time > read_time else read_time + ) + cleanup(subdoc.delete) + + update_time, newdoc = await document.collection("child3").add({"foo": "bar"}) + cleanup(newdoc.delete) + assert update_time > read_time + + # Compare the query at read_time to the query at new update time. 
+ original_children = [doc async for doc in document.collections(read_time=read_time)] + assert sorted(child.id for child in original_children) == sorted(original_child_ids) + + original_children = [doc async for doc in document.collections()] + assert sorted(child.id for child in original_children) == sorted( + original_child_ids + ["child3"] + ) + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID @@ -1062,6 +1131,31 @@ async def test_collection_add(client, cleanup, database): assert set([i async for i in collection3.list_documents()]) == {document_ref5} +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_list_collections_with_read_time(client, cleanup, database): + # TODO(microgen): list_documents is returning a generator, not a list. + # Consider if this is desired. Also, Document isn't hashable. + collection_id = "coll-add" + UNIQUE_RESOURCE_ID + collection = client.collection(collection_id) + + assert set([i async for i in collection.list_documents()]) == set() + + data1 = {"foo": "bar"} + update_time1, document_ref1 = await collection.add(data1) + cleanup(document_ref1.delete) + assert set([i async for i in collection.list_documents()]) == {document_ref1} + + data2 = {"bar": "baz"} + update_time2, document_ref2 = await collection.add(data2) + cleanup(document_ref2.delete) + assert set([i async for i in collection.list_documents()]) == {document_ref1, document_ref2} + assert set([i async for i in collection.list_documents(read_time=update_time1)]) == {document_ref1} + assert set([i async for i in collection.list_documents(read_time=update_time2)]) == { + document_ref1, + document_ref2, + } + + @pytest_asyncio.fixture async def query_docs(client): collection_id = "qs" + UNIQUE_RESOURCE_ID @@ -1389,6 +1483,46 @@ async def test_query_stream_or_get_w_explain_options_analyze_false( 
_verify_explain_metrics_analyze_false(explain_metrics) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_query_stream_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find the most recent read_time in collections + read_time = max( + [(await docref.get()).read_time async for docref in collection.list_documents()] + ) + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = await collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + # Compare query at read_time to query at current time. + query = collection.where(filter=FieldFilter("b", "==", 1)) + values = { + snapshot.id: snapshot.to_dict() async + for snapshot in query.stream(read_time=read_time) + } + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref + + new_values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client @@ -1853,6 +1987,8 @@ async def test_get_all(client, cleanup, database): data3 = {"a": {"b": 5, "c": 6}, "d": 7, "e": 100} write_result3 = await document3.create(data3) + read_time = write_result3.update_time + # 0. Get 3 unique documents, one of which is missing. 
snapshots = [i async for i in client.get_all([document1, document2, document3])] @@ -1891,6 +2027,19 @@ async def test_get_all(client, cleanup, database): restricted3 = {"a": {"b": data3["a"]["b"]}, "d": data3["d"]} check_snapshot(snapshot3, document3, restricted3, write_result3) + # 3. Use ``read_time`` in ``get_all`` + new_data = {"a": {"b": 8, "c": 9}, "d": 10, "e": 1010} + await document1.update(new_data) + await document2.create(new_data) + await document3.update(new_data) + + snapshots = [i async for i in + client.get_all([document1, document2, document3], read_time=read_time) + ] + assert snapshots[0].exists + assert snapshots[1].exists + assert not snapshots[2].exists + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_live_bulk_writer(client, cleanup, database): @@ -2765,6 +2914,50 @@ async def test_async_avg_query_stream_w_explain_options_analyze_false( explain_metrics.execution_stats +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] +) +async def test_aggregation_queries_with_read_time( + collection, async_query, cleanup, database, aggregation_type, expected_value +): + """ + Ensure that all aggregation queries work when read_time is passed into + a query..().get() method + """ + # Find the most recent read_time in collections + read_time = max( + [(await docref.get()).read_time async for docref in collection.list_documents()] + ) + document_data = { + "a": 1, + "b": 9000, + "c": [1, 123123123], + "stats": {"sum": 9001, "product": 9000}, + } + + _, doc_ref = await collection.add(document_data) + cleanup(doc_ref.delete) + + if aggregation_type == "count": + aggregation_query = async_query.count() + elif aggregation_type == "sum": + aggregation_query = collection.sum("stats.product") + elif aggregation_type == "avg": + aggregation_query = collection.avg("stats.product") + + # 
Check that adding the new document data affected the results of the aggregation queries. + new_result = await aggregation_query.get() + assert len(new_result) == 1 + for r in new_result[0]: + assert r.value != expected_value + + old_result = await aggregation_query.get(read_time=read_time) + assert len(old_result) == 1 + for r in old_result[0]: + assert r.value == expected_value + + @firestore.async_transactional async def create_in_transaction_helper( transaction, client, collection_id, cleanup, database @@ -3176,3 +3369,53 @@ async def in_transaction(transaction): # make sure we didn't skip assertions in inner function assert inner_fn_ran is True + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_query_in_transaction_with_read_time(client, cleanup, database): + """ + Test query profiling in transactions. + """ + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(5)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + await doc_refs[0].create({"a": 1, "b": 2}) + await doc_refs[1].create({"a": 1, "b": 1}) + + read_time = max( + [(await docref.get()).read_time for docref in doc_refs] + ) + await doc_refs[2].create({"a": 1, "b": 3}) + + collection = client.collection(collection_id) + query = collection.where(filter=FieldFilter("a", "==", 1)) + + # should work when transaction is initiated through transactional decorator + async with client.transaction() as transaction: + @firestore.async_transactional + async def in_transaction(transaction): + global inner_fn_ran + + new_b_values = [ + docs.get("b") async for docs in await transaction.get(query, read_time=read_time) + ] + assert len(new_b_values) == 2 + assert 1 in new_b_values + assert 2 in new_b_values + assert 3 not in new_b_values + + new_b_values = [ + docs.get("b") async for docs in await transaction.get(query) + ] + assert 
len(new_b_values) == 3 + assert 1 in new_b_values + assert 2 in new_b_values + assert 3 in new_b_values + + inner_fn_ran = True + + await in_transaction(transaction) + # make sure we didn't skip assertions in inner function + assert inner_fn_ran is True \ No newline at end of file diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 6254c4c87f..a2348807e2 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -321,9 +321,39 @@ def test_async_aggregation_query_prep_stream_with_explain_options(): assert kwargs == {"retry": None} +def test_async_aggregation_query_prep_stream_with_read_time(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") + + # 1800 seconds after epoch + read_time = datetime.now() + + request, kwargs = aggregation_query._prep_stream(read_time=read_time) + + parent_path, _ = parent._parent_info() + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + "read_time": read_time, + } + assert request == expected_request + assert kwargs == {"retry": None} + + @pytest.mark.asyncio async def _async_aggregation_query_get_helper( - retry=None, timeout=None, read_time=None, explain_options=None + retry=None, + timeout=None, + explain_options=None, + response_read_time=None, + query_read_time=None, ): from google.cloud._helpers import _datetime_to_pb_timestamp @@ -342,7 +372,11 @@ async def _async_aggregation_query_get_helper( aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") - aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time) + aggregation_result = 
AggregationResult( + alias="total", + value=5, + read_time=response_read_time, + ) if explain_options is not None: explain_metrics = {"execution_stats": {"results_returned": 1}} @@ -351,14 +385,18 @@ async def _async_aggregation_query_get_helper( response_pb = make_aggregation_query_response( [aggregation_result], - read_time=read_time, + read_time=response_read_time, explain_metrics=explain_metrics, ) firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - returned = await aggregation_query.get(**kwargs, explain_options=explain_options) + returned = await aggregation_query.get( + **kwargs, + explain_options=explain_options, + read_time=query_read_time, + ) assert isinstance(returned, QueryResultsList) assert len(returned) == 1 @@ -366,9 +404,9 @@ async def _async_aggregation_query_get_helper( for r in result: assert r.alias == aggregation_result.alias assert r.value == aggregation_result.value - if read_time is not None: + if response_read_time is not None: result_datetime = _datetime_to_pb_timestamp(r.read_time) - assert result_datetime == read_time + assert result_datetime == response_read_time if explain_options is None: with pytest.raises(QueryExplainError, match="explain_options not set"): @@ -387,6 +425,8 @@ async def _async_aggregation_query_get_helper( } if explain_options is not None: expected_request["explain_options"] = explain_options._to_dict() + if query_read_time is not None: + expected_request["read_time"] = query_read_time firestore_api.run_aggregation_query.assert_called_once_with( request=expected_request, metadata=client._rpc_metadata, @@ -405,7 +445,9 @@ async def test_async_aggregation_query_get_with_readtime(): one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) read_time = _datetime_to_pb_timestamp(one_hour_ago) - await _async_aggregation_query_get_helper(read_time=read_time) + await 
_async_aggregation_query_get_helper( + query_read_time=one_hour_ago, response_read_time=read_time + ) @pytest.mark.asyncio @@ -583,7 +625,9 @@ async def _async_aggregation_query_stream_helper( kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - returned = aggregation_query.stream(**kwargs, explain_options=explain_options) + returned = aggregation_query.stream( + **kwargs, explain_options=explain_options, read_time=read_time, + ) assert isinstance(returned, AsyncStreamGenerator) results = [] @@ -611,6 +655,8 @@ async def _async_aggregation_query_stream_helper( } if explain_options is not None: expected_request["explain_options"] = explain_options._to_dict() + if read_time is not None: + expected_request["read_time"] = read_time # Verify the mock call. firestore_api.run_aggregation_query.assert_called_once_with( @@ -625,6 +671,15 @@ async def test_aggregation_query_stream(): await _async_aggregation_query_stream_helper() +@pytest.mark.asyncio +async def test_async_aggregation_query_stream_with_read_time(): + from google.cloud._helpers import _datetime_to_pb_timestamp + + one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) + read_time = _datetime_to_pb_timestamp(one_hour_ago) + await _async_aggregation_query_stream_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_aggregation_query_stream_w_explain_options_analyze_true(): from google.cloud.firestore_v1.query_profile import ExplainOptions diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index ee624d382b..4924856a84 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -187,7 +187,7 @@ def test_asyncclient_document_factory_w_nested_path(): assert isinstance(document2, AsyncDocumentReference) -async def _collections_helper(retry=None, timeout=None): +async def _collections_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import 
_helpers from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -206,7 +206,7 @@ async def __aiter__(self, **_): client._firestore_api_internal = firestore_api kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - collections = [c async for c in client.collections(**kwargs)] + collections = [c async for c in client.collections(read_time=read_time, **kwargs)] assert len(collections) == len(collection_ids) for collection, collection_id in zip(collections, collection_ids): @@ -215,8 +215,13 @@ async def __aiter__(self, **_): assert collection.id == collection_id base_path = client._database_string + "/documents" + expected_request = { + "parent": base_path, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -236,6 +241,12 @@ async def test_asyncclient_collections_w_retry_timeout(): await _collections_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asyncclient_collections_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _collections_helper(read_time=read_time) + + async def _invoke_get_all(client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. 
firestore_api = AsyncMock(spec=["batch_get_documents"]) @@ -252,7 +263,13 @@ async def _invoke_get_all(client, references, document_pbs, **kwargs): return [s async for s in snapshots] -async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None): +async def _get_all_helper( + num_snapshots=2, + txn_id=None, + retry=None, + timeout=None, + read_time=None, +): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.types import common @@ -261,13 +278,13 @@ async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None data1 = {"a": "cheese"} document1 = client.document("pineapple", "lamp1") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) + document_pb1, doc_read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=doc_read_time) data2 = {"b": True, "c": 18} document2 = client.document("pineapple", "lamp2") - document, read_time = _doc_get_info(document2._document_path, data2) - response2 = _make_batch_response(found=document, read_time=read_time) + document, doc_read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=doc_read_time) document3 = client.document("pineapple", "lamp3") response3 = _make_batch_response(missing=document3._document_path) @@ -290,6 +307,7 @@ async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None documents, responses, field_paths=field_paths, + read_time=read_time, **kwargs, ) @@ -308,14 +326,17 @@ async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None mask = common.DocumentMask(field_paths=field_paths) kwargs.pop("transaction", None) + expected_request = { + "database": client._database_string, + "documents": doc_paths, + "mask": 
mask, + "transaction": txn_id, + } + if read_time is not None: + expected_request["read_time"] = read_time client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": mask, - "transaction": txn_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -346,6 +367,12 @@ async def test_asyncclient_get_all_wrong_order(): await _get_all_helper(num_snapshots=3) +@pytest.mark.asyncio +async def test_asyncclient_get_all_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_all_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_asyncclient_get_all_unknown_result(): from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 497fc455fa..214bda26f7 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -17,6 +17,7 @@ import mock import pytest +from datetime import datetime, timezone from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT, make_async_client from tests.unit.v1.test__helpers import AsyncIter, AsyncMock @@ -302,7 +303,7 @@ async def _get_chunk(*args, **kwargs): @pytest.mark.asyncio -async def _list_documents_helper(page_size=None, retry=None, timeout=None): +async def _list_documents_helper(page_size=None, retry=None, timeout=None, read_time=None): from google.api_core.page_iterator import Page from google.api_core.page_iterator_async import AsyncIterator @@ -340,10 +341,11 @@ async def _next_page(self): async for i in collection.list_documents( page_size=page_size, **kwargs, + read_time=read_time ) ] else: - documents = [i async for i in collection.list_documents(**kwargs)] + documents = [i async for i in collection.list_documents(**kwargs, read_time=read_time)] # Verify the response and the mocks. 
assert len(documents) == len(document_ids) @@ -353,14 +355,17 @@ async def _next_page(self): assert document.id == document_id parent, _ = collection._parent_info() + expected_request = { + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + "mask": {"field_paths": None}, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.list_documents.assert_called_once_with( - request={ - "parent": parent, - "collection_id": collection.id, - "page_size": page_size, - "show_missing": True, - "mask": {"field_paths": None}, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -385,6 +390,11 @@ async def test_asynccollectionreference_list_documents_w_page_size(): await _list_documents_helper(page_size=25) +@pytest.mark.asyncio +async def test_asynccollectionreference_list_documents_w_read_time(): + _list_documents_helper(read_time=datetime.now(tz=timezone.utc)) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_asynccollectionreference_get(query_class): @@ -449,6 +459,20 @@ async def test_asynccollectionreference_get_w_explain_options(query_class): transaction=None, explain_options=explain_options ) +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_get_w_read_time(query_class): + read_time = datetime.now(tz=timezone.utc) + collection = _make_async_collection_reference("collection") + await collection.get(read_time=read_time) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.get.assert_called_once_with( + transaction=None, + read_time=read_time, + ) + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio @@ -552,6 +576,23 @@ async def response_generator(): assert 
explain_metrics.execution_stats.results_returned == 1 +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_stream_w_read_time(query_class): + read_time = datetime.now(tz=timezone.utc) + collection = _make_async_collection_reference("collection") + get_response = collection.stream(read_time=read_time) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + + assert get_response is query_instance.stream.return_value + query_instance.stream.assert_called_once_with( + transaction=None, + read_time=read_time, + ) + + def test_asynccollectionreference_recursive(): from google.cloud.firestore_v1.async_query import AsyncQuery diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 8d67e78f08..7d54f2e355 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -17,6 +17,9 @@ import mock import pytest +from datetime import datetime + +from google.protobuf import timestamp_pb2 from tests.unit.v1._test_helpers import make_async_client from tests.unit.v1.test__helpers import AsyncIter, AsyncMock @@ -399,6 +402,7 @@ async def _get_helper( return_empty=False, retry=None, timeout=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.transaction import Transaction @@ -407,10 +411,14 @@ async def _get_helper( # Create a minimal fake GAPIC with a dummy response. 
create_time = 123 update_time = 234 - read_time = 345 + if read_time: + response_read_time = timestamp_pb2.Timestamp() + response_read_time.FromDatetime(read_time) + else: + response_read_time = 345 firestore_api = AsyncMock(spec=["batch_get_documents"]) response = mock.create_autospec(firestore.BatchGetDocumentsResponse) - response.read_time = 345 + response.read_time = response_read_time response.found = mock.create_autospec(document.Document) response.found.fields = {} response.found.create_time = create_time @@ -445,6 +453,7 @@ def WhichOneof(val): field_paths=field_paths, transaction=transaction, **kwargs, + read_time=read_time, ) assert snapshot.reference is document_reference @@ -457,7 +466,7 @@ def WhichOneof(val): else: assert snapshot.to_dict() == {} assert snapshot.exists - assert snapshot.read_time is read_time + assert snapshot.read_time is response_read_time assert snapshot.create_time is create_time assert snapshot.update_time is update_time @@ -471,14 +480,18 @@ def WhichOneof(val): expected_transaction_id = transaction_id else: expected_transaction_id = None + + expected_request = { + "database": client._database_string, + "documents": [document_reference._document_path], + "mask": mask, + "transaction": expected_transaction_id, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": [document_reference._document_path], - "mask": mask, - "transaction": expected_transaction_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -530,7 +543,12 @@ async def test_asyncdocumentreference_get_with_transaction(): @pytest.mark.asyncio -async def _collections_helper(page_size=None, retry=None, timeout=None): +async def test_asyncdocumentreference_get_with_read_time(): + await _get_helper(read_time=datetime.now()) + + +@pytest.mark.asyncio +async def _collections_helper(page_size=None, 
retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -553,10 +571,10 @@ async def __aiter__(self, **_): document = _make_async_document_reference("where", "we-are", client=client) if page_size is not None: collections = [ - c async for c in document.collections(page_size=page_size, **kwargs) + c async for c in document.collections(page_size=page_size, **kwargs, read_time=read_time) ] else: - collections = [c async for c in document.collections(**kwargs)] + collections = [c async for c in document.collections(**kwargs, read_time=read_time)] # Verify the response and the mocks. assert len(collections) == len(collection_ids) @@ -564,9 +582,16 @@ async def __aiter__(self, **_): assert isinstance(collection, AsyncCollectionReference) assert collection.parent == document assert collection.id == collection_id + + expected_result = { + "parent": document._document_path, + "page_size": page_size, + } + if read_time is not None: + expected_result["read_time"] = read_time firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, + request=expected_result, metadata=client._rpc_metadata, **kwargs, ) @@ -586,6 +611,11 @@ async def test_asyncdocumentreference_collections_w_retry_timeout(): await _collections_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_documentreference_collections_w_read_time(): + await _collections_helper(read_time=datetime.now()) + + @pytest.mark.asyncio async def test_asyncdocumentreference_collections_w_page_size(): await _collections_helper(page_size=10) diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index efc6c7df78..5918766c82 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # 
limitations under the License. +import datetime import types import mock @@ -41,7 +42,7 @@ def test_asyncquery_constructor(): assert not query._all_descendants -async def _get_helper(retry=None, timeout=None, explain_options=None): +async def _get_helper(retry=None, timeout=None, explain_options=None, read_time=None): from google.cloud.firestore_v1 import _helpers # Create a minimal fake GAPIC. @@ -68,7 +69,7 @@ async def _get_helper(retry=None, timeout=None, explain_options=None): # Execute the query and check the response. query = make_async_query(parent) - returned = await query.get(**kwargs, explain_options=explain_options) + returned = await query.get(**kwargs, explain_options=explain_options, read_time=read_time) assert isinstance(returned, QueryResultsList) assert len(returned) == 1 @@ -94,6 +95,8 @@ async def _get_helper(retry=None, timeout=None, explain_options=None): } if explain_options: request["explain_options"] = explain_options._to_dict() + if read_time: + request["read_time"] = read_time # Verify the mock call. 
firestore_api.run_query.assert_called_once_with( @@ -117,6 +120,12 @@ async def test_asyncquery_get_w_retry_timeout(): await _get_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asyncquery_get_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_asyncquery_get_limit_to_last(): from google.cloud import firestore @@ -336,7 +345,7 @@ async def test_asyncquery_chunkify_w_chunksize_gt_limit(): assert [snapshot.id for snapshot in chunks[0]] == expected_ids -async def _stream_helper(retry=None, timeout=None, explain_options=None): +async def _stream_helper(retry=None, timeout=None, explain_options=None, read_time=None): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -367,7 +376,9 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None): # Execute the query and check the response. query = make_async_query(parent) - stream_response = query.stream(**kwargs, explain_options=explain_options) + stream_response = query.stream( + **kwargs, explain_options=explain_options, read_time=read_time + ) assert isinstance(stream_response, AsyncStreamGenerator) returned = [x async for x in stream_response] @@ -395,6 +406,8 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None): } if explain_options is not None: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time # Verify the mock call. 
firestore_api.run_query.assert_called_once_with( @@ -417,6 +430,10 @@ async def test_asyncquery_stream_w_retry_timeout(): timeout = 123.0 await _stream_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asyncquery_stream_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _stream_helper(read_time=read_time) @pytest.mark.asyncio async def test_asyncquery_stream_with_limit_to_last(): @@ -481,6 +498,57 @@ async def test_asyncquery_stream_with_transaction(): ) +@pytest.mark.asyncio +async def test_asyncquery_stream_with_transaction_and_read_time(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Create a read_time for this client. + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = AsyncIter([response_pb]) + + # Execute the query and check the response. 
+ query = make_async_query(parent) + get_response = query.stream(transaction=transaction, read_time=read_time) + assert isinstance(get_response, AsyncStreamGenerator) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("declaration", "burger") + assert snapshot.to_dict() == data + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + "read_time": read_time, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio async def test_asyncquery_stream_no_results(): from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -718,7 +786,7 @@ def test_asynccollectiongroup_constructor_all_descendents_is_false(): @pytest.mark.asyncio -async def _get_partitions_helper(retry=None, timeout=None): +async def _get_partitions_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers # Create a minimal fake GAPIC. @@ -743,7 +811,7 @@ async def _get_partitions_helper(retry=None, timeout=None): # Execute the query and check the response. 
query = _make_async_collection_group(parent) - get_response = query.get_partitions(2, **kwargs) + get_response = query.get_partitions(2, read_time=read_time, **kwargs) assert isinstance(get_response, types.AsyncGeneratorType) returned = [i async for i in get_response] @@ -755,12 +823,15 @@ async def _get_partitions_helper(retry=None, timeout=None): parent, orders=(query._make_order("__name__", query.ASCENDING),), ) + expected_request = { + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.partition_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": partition_query._to_protobuf(), - "partition_count": 2, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -780,6 +851,12 @@ async def test_asynccollectiongroup_get_partitions_w_retry_timeout(): await _get_partitions_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_partitions_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_asynccollectiongroup_get_partitions_w_filter(): # Make a **real** collection reference as parent. diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index e4bb788e3d..70929e29d4 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime import mock import pytest @@ -294,13 +295,15 @@ async def test_asynctransaction__commit_failure(): ) -async def _get_all_helper(retry=None, timeout=None): +async def _get_all_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers client = AsyncMock(spec=["get_all"]) transaction = _make_async_transaction(client) ref1, ref2 = mock.Mock(), mock.Mock() kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time result = await transaction.get_all([ref1, ref2], **kwargs) @@ -326,7 +329,13 @@ async def test_asynctransaction_get_all_w_retry_timeout(): await _get_all_helper(retry=retry, timeout=timeout) -async def _get_w_document_ref_helper(retry=None, timeout=None, explain_options=None): +@pytest.mark.asyncio +async def test_asynctransaction_get_all_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_all_helper(read_time=read_time) + + +async def _get_w_document_ref_helper(retry=None, timeout=None, explain_options=None, read_time=None): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -335,7 +344,12 @@ async def _get_w_document_ref_helper(retry=None, timeout=None, explain_options=N ref = AsyncDocumentReference("documents", "doc-id") kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - result = await transaction.get(ref, **kwargs, explain_options=explain_options) + if explain_options is not None: + kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time + + result = await transaction.get(ref, **kwargs) client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) assert result is client.get_all.return_value @@ -356,7 +370,7 @@ async def test_asynctransaction_get_w_document_ref_w_retry_timeout(): @pytest.mark.asyncio -async def 
test_transaction_get_w_document_ref_w_explain_options(): +async def test_asynctransaction_get_w_document_ref_w_explain_options(): from google.cloud.firestore_v1.query_profile import ExplainOptions with pytest.raises(ValueError, match="`explain_options` cannot be provided."): @@ -365,7 +379,14 @@ async def test_transaction_get_w_document_ref_w_explain_options(): ) -async def _get_w_query_helper(retry=None, timeout=None, explain_options=None): +@pytest.mark.asyncio +async def test_asynctransaction_get_w_document_ref_w_read_time(): + await _get_w_document_ref_helper( + read_time=datetime.datetime.now(tz=datetime.timezone.utc) + ) + + +async def _get_w_query_helper(retry=None, timeout=None, explain_options=None, read_time=None): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -407,6 +428,7 @@ async def _get_w_query_helper(retry=None, timeout=None, explain_options=None): query, **kwargs, explain_options=explain_options, + read_time=read_time, ) # Verify the response. @@ -435,6 +457,8 @@ async def _get_w_query_helper(retry=None, timeout=None, explain_options=None): } if explain_options is not None: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time # Verify the mock call. 
firestore_api.run_query.assert_called_once_with( @@ -462,6 +486,12 @@ async def test_transaction_get_w_query_w_explain_options(): await _get_w_query_helper(explain_options=ExplainOptions(analyze=True)) +@pytest.mark.asyncio +async def test_asynctransaction_get_w_query_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_w_query_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_asynctransaction_get_failure(): client = _make_client() From 8764d63eed64c777c84b699265c5d8ab25a8f442 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 27 May 2025 14:46:03 +0000 Subject: [PATCH 2/4] used TYPE_CHECKING; fixed unit tests --- google/cloud/firestore_v1/async_aggregation.py | 5 ++--- google/cloud/firestore_v1/async_client.py | 4 ++-- google/cloud/firestore_v1/async_collection.py | 4 ++-- google/cloud/firestore_v1/async_query.py | 4 ++-- google/cloud/firestore_v1/async_transaction.py | 4 ++-- tests/unit/v1/test_async_collection.py | 2 +- 6 files changed, 11 insertions(+), 12 deletions(-) diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index 63a29f3447..b73b8b3e5d 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -20,8 +20,6 @@ """ from __future__ import annotations -import datetime - from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Union from google.api_core import gapic_v1 @@ -36,11 +34,12 @@ from google.cloud.firestore_v1.query_results import QueryResultsList if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.base_aggregation import AggregationResult from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions import google.cloud.firestore_v1.types.query_profile as query_profile_pb - class AsyncAggregationQuery(BaseAggregationQuery): """Represents an aggregation query to the Firestore API.""" diff --git 
a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index f675150ccd..3e302e7b09 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -25,8 +25,6 @@ """ from __future__ import annotations -import datetime - from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union from google.api_core import gapic_v1 @@ -51,6 +49,8 @@ ) if TYPE_CHECKING: + import datetime + from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 08a9ae516f..1b71372dd2 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -15,8 +15,6 @@ """Classes for representing collections for the Google Cloud Firestore API.""" from __future__ import annotations -import datetime - from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple from google.api_core import gapic_v1 @@ -36,6 +34,8 @@ from google.cloud.firestore_v1.document import DocumentReference if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index f8ebcbb920..4c636c6213 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,8 +20,6 @@ """ from __future__ import annotations -import datetime - from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type from google.api_core import gapic_v1 @@ -42,6 +40,8 @@ from google.cloud.firestore_v1.query_results import QueryResultsList if TYPE_CHECKING: # pragma: NO COVER + import datetime + # Types needed 
only for Type Hints from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.base_document import DocumentSnapshot diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 2e89f88d58..be8668cd62 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -15,8 +15,6 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" from __future__ import annotations -import datetime - from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional from google.api_core import exceptions, gapic_v1 @@ -38,6 +36,8 @@ # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 214bda26f7..47daebc92d 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -392,7 +392,7 @@ async def test_asynccollectionreference_list_documents_w_page_size(): @pytest.mark.asyncio async def test_asynccollectionreference_list_documents_w_read_time(): - _list_documents_helper(read_time=datetime.now(tz=timezone.utc)) + await _list_documents_helper(read_time=datetime.now(tz=timezone.utc)) @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) From 9612df4694260f4529794d70e7f4a5639c66ef49 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 27 May 2025 17:49:39 +0000 Subject: [PATCH 3/4] linting + fixing cover --- .../cloud/firestore_v1/async_aggregation.py | 4 +- google/cloud/firestore_v1/async_client.py | 4 +- google/cloud/firestore_v1/async_query.py | 2 +- tests/system/test_system_async.py | 
38 ++++++++++++------- tests/unit/v1/test_async_aggregation.py | 4 +- tests/unit/v1/test_async_collection.py | 13 ++++--- tests/unit/v1/test_async_document.py | 13 +++++-- tests/unit/v1/test_async_query.py | 10 ++++- tests/unit/v1/test_async_transaction.py | 8 +++- 9 files changed, 64 insertions(+), 32 deletions(-) diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index b73b8b3e5d..e273f514ab 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -34,11 +34,11 @@ from google.cloud.firestore_v1.query_results import QueryResultsList if TYPE_CHECKING: # pragma: NO COVER - import datetime - from google.cloud.firestore_v1.base_aggregation import AggregationResult from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions import google.cloud.firestore_v1.types.query_profile as query_profile_pb + import datetime + class AsyncAggregationQuery(BaseAggregationQuery): """Represents an aggregation query to the Firestore API.""" diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 3e302e7b09..9169f02deb 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -48,10 +48,10 @@ grpc_asyncio as firestore_grpc_transport, ) -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER import datetime - from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER + from google.cloud.firestore_v1.bulk_writer import BulkWriter class AsyncClient(BaseClient): diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 4c636c6213..98de75bd63 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -472,7 +472,7 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, - read_time=read_time + read_time=read_time, ) return 
AsyncStreamGenerator(inner_generator, explain_options) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 6ab22d6d8f..ed9984954d 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -261,7 +261,9 @@ async def test_collections_w_read_time(client, cleanup, database): assert first_collection_id in ids # We're just testing that we added one collection at read_time, not two. - collections = [x async for x in client.collections(retry=RETRIES, read_time=read_time)] + collections = [ + x async for x in client.collections(retry=RETRIES, read_time=read_time) + ] ids = [collection.id for collection in collections] assert second_collection_id not in ids assert first_collection_id in ids @@ -1148,9 +1150,16 @@ async def test_list_collections_with_read_time(client, cleanup, database): data2 = {"bar": "baz"} update_time2, document_ref2 = await collection.add(data2) cleanup(document_ref2.delete) - assert set([i async for i in collection.list_documents()]) == {document_ref1, document_ref2} - assert set([i async for i in collection.list_documents(read_time=update_time1)]) == {document_ref1} - assert set([i async for i in collection.list_documents(read_time=update_time2)]) == { + assert set([i async for i in collection.list_documents()]) == { + document_ref1, + document_ref2, + } + assert set( + [i async for i in collection.list_documents(read_time=update_time1)] + ) == {document_ref1} + assert set( + [i async for i in collection.list_documents(read_time=update_time2)] + ) == { document_ref1, document_ref2, } @@ -1506,8 +1515,8 @@ async def test_query_stream_w_read_time(query_docs, cleanup, database): # Compare query at read_time to query at current time. 
query = collection.where(filter=FieldFilter("b", "==", 1)) values = { - snapshot.id: snapshot.to_dict() async - for snapshot in query.stream(read_time=read_time) + snapshot.id: snapshot.to_dict() + async for snapshot in query.stream(read_time=read_time) } assert len(values) == num_vals assert new_ref.id not in values @@ -2033,8 +2042,11 @@ async def test_get_all(client, cleanup, database): await document2.create(new_data) await document3.update(new_data) - snapshots = [i async for i in - client.get_all([document1, document2, document3], read_time=read_time) + snapshots = [ + i + async for i in client.get_all( + [document1, document2, document3], read_time=read_time + ) ] assert snapshots[0].exists assert snapshots[1].exists @@ -3384,9 +3396,7 @@ async def test_query_in_transaction_with_read_time(client, cleanup, database): await doc_refs[0].create({"a": 1, "b": 2}) await doc_refs[1].create({"a": 1, "b": 1}) - read_time = max( - [(await docref.get()).read_time for docref in doc_refs] - ) + read_time = max([(await docref.get()).read_time for docref in doc_refs]) await doc_refs[2].create({"a": 1, "b": 3}) collection = client.collection(collection_id) @@ -3394,12 +3404,14 @@ async def test_query_in_transaction_with_read_time(client, cleanup, database): # should work when transaction is initiated through transactional decorator async with client.transaction() as transaction: + @firestore.async_transactional async def in_transaction(transaction): global inner_fn_ran new_b_values = [ - docs.get("b") async for docs in await transaction.get(query, read_time=read_time) + docs.get("b") + async for docs in await transaction.get(query, read_time=read_time) ] assert len(new_b_values) == 2 assert 1 in new_b_values @@ -3418,4 +3430,4 @@ async def in_transaction(transaction): await in_transaction(transaction) # make sure we didn't skip assertions in inner function - assert inner_fn_ran is True \ No newline at end of file + assert inner_fn_ran is True diff --git 
a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index a2348807e2..9140f53e81 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -626,7 +626,9 @@ async def _async_aggregation_query_stream_helper( # Execute the query and check the response. returned = aggregation_query.stream( - **kwargs, explain_options=explain_options, read_time=read_time, + **kwargs, + explain_options=explain_options, + read_time=read_time, ) assert isinstance(returned, AsyncStreamGenerator) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 47daebc92d..a0194ace5b 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -303,7 +303,9 @@ async def _get_chunk(*args, **kwargs): @pytest.mark.asyncio -async def _list_documents_helper(page_size=None, retry=None, timeout=None, read_time=None): +async def _list_documents_helper( + page_size=None, retry=None, timeout=None, read_time=None +): from google.api_core.page_iterator import Page from google.api_core.page_iterator_async import AsyncIterator @@ -339,13 +341,13 @@ async def _next_page(self): documents = [ i async for i in collection.list_documents( - page_size=page_size, - **kwargs, - read_time=read_time + page_size=page_size, **kwargs, read_time=read_time ) ] else: - documents = [i async for i in collection.list_documents(**kwargs, read_time=read_time)] + documents = [ + i async for i in collection.list_documents(**kwargs, read_time=read_time) + ] # Verify the response and the mocks. 
assert len(documents) == len(document_ids) @@ -459,6 +461,7 @@ async def test_asynccollectionreference_get_w_explain_options(query_class): transaction=None, explain_options=explain_options ) + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_asynccollectionreference_get_w_read_time(query_class): diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 7d54f2e355..45472c6604 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -480,7 +480,7 @@ def WhichOneof(val): expected_transaction_id = transaction_id else: expected_transaction_id = None - + expected_request = { "database": client._database_string, "documents": [document_reference._document_path], @@ -571,10 +571,15 @@ async def __aiter__(self, **_): document = _make_async_document_reference("where", "we-are", client=client) if page_size is not None: collections = [ - c async for c in document.collections(page_size=page_size, **kwargs, read_time=read_time) + c + async for c in document.collections( + page_size=page_size, **kwargs, read_time=read_time + ) ] else: - collections = [c async for c in document.collections(**kwargs, read_time=read_time)] + collections = [ + c async for c in document.collections(**kwargs, read_time=read_time) + ] # Verify the response and the mocks. 
assert len(collections) == len(collection_ids) @@ -582,7 +587,7 @@ async def __aiter__(self, **_): assert isinstance(collection, AsyncCollectionReference) assert collection.parent == document assert collection.id == collection_id - + expected_result = { "parent": document._document_path, "page_size": page_size, diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 5918766c82..54c80e5ad4 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -69,7 +69,9 @@ async def _get_helper(retry=None, timeout=None, explain_options=None, read_time= # Execute the query and check the response. query = make_async_query(parent) - returned = await query.get(**kwargs, explain_options=explain_options, read_time=read_time) + returned = await query.get( + **kwargs, explain_options=explain_options, read_time=read_time + ) assert isinstance(returned, QueryResultsList) assert len(returned) == 1 @@ -345,7 +347,9 @@ async def test_asyncquery_chunkify_w_chunksize_gt_limit(): assert [snapshot.id for snapshot in chunks[0]] == expected_ids -async def _stream_helper(retry=None, timeout=None, explain_options=None, read_time=None): +async def _stream_helper( + retry=None, timeout=None, explain_options=None, read_time=None +): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -430,11 +434,13 @@ async def test_asyncquery_stream_w_retry_timeout(): timeout = 123.0 await _stream_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_asyncquery_stream_w_read_time(): read_time = datetime.datetime.now(tz=datetime.timezone.utc) await _stream_helper(read_time=read_time) + @pytest.mark.asyncio async def test_asyncquery_stream_with_limit_to_last(): # Attach the fake GAPIC to a real client. 
diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index 70929e29d4..d357e3482a 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -335,7 +335,9 @@ async def test_asynctransaction_get_all_w_read_time(): await _get_all_helper(read_time=read_time) -async def _get_w_document_ref_helper(retry=None, timeout=None, explain_options=None, read_time=None): +async def _get_w_document_ref_helper( + retry=None, timeout=None, explain_options=None, read_time=None +): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -386,7 +388,9 @@ async def test_asynctransaction_get_w_document_ref_w_read_time(): ) -async def _get_w_query_helper(retry=None, timeout=None, explain_options=None, read_time=None): +async def _get_w_query_helper( + retry=None, timeout=None, explain_options=None, read_time=None +): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator From a30926cbb2a9b1dc20a1f897e013d1b959aef955 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 27 May 2025 19:32:31 +0000 Subject: [PATCH 4/4] final linting --- google/cloud/firestore_v1/async_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 9169f02deb..15b31af314 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -48,7 +48,7 @@ grpc_asyncio as firestore_grpc_transport, ) -if TYPE_CHECKING: # pragma: NO COVER +if TYPE_CHECKING: # pragma: NO COVER import datetime from google.cloud.firestore_v1.bulk_writer import BulkWriter