From a7a0d4dadc0eb690d4cfab1851acfb0ef59a4480 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 15:38:07 -0700 Subject: [PATCH 1/9] enable type checks --- noxfile.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/noxfile.py b/noxfile.py index 7ef3ed5b8..9e81d7179 100644 --- a/noxfile.py +++ b/noxfile.py @@ -155,9 +155,16 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") - # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.install("mypy", "types-setuptools", "types-protobuf") + session.run( + "mypy", + "-p", + "google.cloud.firestore_v1", + "--no-incremental", + "--check-untyped-defs", + "--exclude", + "services", + ) @nox.session(python=DEFAULT_PYTHON_VERSION) From 13e7e1c3a50ed21f4dffd92f1c28e9a011d7964b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 16:15:28 -0700 Subject: [PATCH 2/9] fixed mypy issues --- google/cloud/firestore_v1/async_collection.py | 4 +-- google/cloud/firestore_v1/base_aggregation.py | 23 ++++++------ google/cloud/firestore_v1/base_collection.py | 16 +++++---- google/cloud/firestore_v1/bulk_writer.py | 3 +- google/cloud/firestore_v1/watch.py | 35 ++++++++++--------- 5 files changed, 45 insertions(+), 36 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 8c832b8f4..75db41c5e 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -31,9 +31,9 @@ BaseCollectionReference, _item_to_document_ref, ) -from google.cloud.firestore_v1.document import DocumentReference if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions @@ -162,7 +162,7 @@ async def list_documents( page_size: int | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, - ) -> AsyncGenerator[DocumentReference, None]: + ) -> AsyncGenerator[AsyncDocumentReference, None]: """List all subdocuments of the current collection. Args: diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 34a3baad8..98fa95153 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -80,23 +80,24 @@ def __init__(self, alias: str | None = None): def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() return aggregation_pb class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): - if isinstance(field_ref, FieldPath): - # convert field path to string - field_ref = field_ref.to_api_repr() - self.field_ref = field_ref + # convert field path to string if needed + field_str = field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + self.field_ref: str = field_str super(SumAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum() aggregation_pb.sum.field.field_path = self.field_ref return aggregation_pb @@ -104,16 +105,16 @@ def _to_protobuf(self): class AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): - if isinstance(field_ref, FieldPath): - # convert field path to string - field_ref = field_ref.to_api_repr() - self.field_ref = field_ref + # convert field path to string if needed + field_str = field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + self.field_ref: str = field_str super(AvgAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg() aggregation_pb.avg.field.field_path = self.field_ref return aggregation_pb diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index b113da827..416999c76 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -43,7 +43,7 @@ BaseVectorQuery, DistanceMeasure, ) - from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList @@ -128,7 +128,7 @@ def _aggregation_query(self) -> BaseAggregationQuery: def _vector_query(self) -> BaseVectorQuery: raise NotImplementedError - def document(self, document_id: Optional[str] = None) -> DocumentReference: + def document(self, document_id: Optional[str] = None): """Create a sub-document underneath the current collection. Args: @@ -138,7 +138,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: uppercase and lowercase and letters. Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: + :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: The child document. """ if document_id is None: @@ -178,7 +178,7 @@ def _prep_add( document_id: Optional[str] = None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, - ) -> Tuple[DocumentReference, dict]: + ): """Shared setup for async / sync :method:`add`""" if document_id is None: document_id = _auto_id() @@ -225,7 +225,7 @@ def list_documents( retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Union[ - Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] + Generator[BaseDocumentReference, Any, Any], AsyncGenerator[BaseDocumentReference, Any] ]: raise NotImplementedError @@ -601,13 +601,17 @@ def _auto_id() -> str: return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20)) -def _item_to_document_ref(collection_reference, item) -> DocumentReference: +def _item_to_document_ref(collection_reference, item): """Convert Document resource to document ref. Args: collection_reference (google.api_core.page_iterator.GRPCIterator): iterator response item (dict): document resource + + Returns: + :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: + The child document """ document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] return collection_reference.document(document_id) diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index eff936300..6747bc234 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -110,7 +110,7 @@ def wrapper(self, *args, **kwargs): # For code parity, even `SendMode.serial` scenarios should return # a future here. Anything else would badly complicate calling code. result = fn(self, *args, **kwargs) - future = concurrent.futures.Future() + future: concurrent.futures.Future = concurrent.futures.Future() future.set_result(result) return future @@ -319,6 +319,7 @@ def __init__( self._total_batches_sent: int = 0 self._total_write_operations: int = 0 + self._executor: concurrent.futures.ThreadPoolExecutor self._ensure_executor() @staticmethod diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 79933aeca..10c37db02 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import collections import functools @@ -232,7 +233,7 @@ def __init__( def _init_stream(self): rpc_request = self._get_rpc_request - self._rpc = ResumableBidiRpc( + self._rpc: ResumableBidiRpc | None = ResumableBidiRpc( start_rpc=self._api._transport.listen, should_recover=_should_recover, should_terminate=_should_terminate, @@ -243,7 +244,7 @@ def _init_stream(self): self._rpc.add_done_callback(self._on_rpc_done) # The server assigns and updates the resume token. - self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot) + self._consumer: BackgroundConsumer | None = BackgroundConsumer(self._rpc, self.on_snapshot) self._consumer.start() @classmethod @@ -330,16 +331,18 @@ def close(self, reason=None): return # Stop consuming messages. - if self.is_active: - _LOGGER.debug("Stopping consumer.") - self._consumer.stop() - self._consumer._on_response = None + if self._consumer: + if self.is_active: + _LOGGER.debug("Stopping consumer.") + self._consumer.stop() + self._consumer._on_response = None self._consumer = None self._snapshot_callback = None - self._rpc.close() - self._rpc._initial_request = None - self._rpc._callbacks = [] + if self._rpc: + self._rpc.close() + self._rpc._initial_request = None + self._rpc._callbacks = [] self._rpc = None self._closed = True _LOGGER.debug("Finished stopping manager.") @@ -460,13 +463,13 @@ def on_snapshot(self, proto): message = f"Unknown target change type: {target_change_type}" _LOGGER.info(f"on_snapshot: {message}") self.close(reason=ValueError(message)) - - try: - # Use 'proto' vs 'pb' for datetime handling - meth(self, proto.target_change) - except Exception as exc2: - _LOGGER.debug(f"meth(proto) exc: {exc2}") - raise + else: + try: + # Use 'proto' vs 'pb' for datetime handling + meth(self, proto.target_change) + except Exception as exc2: + _LOGGER.debug(f"meth(proto) exc: {exc2}") + raise # NOTE: # in other implementations, such as node, the backoff is reset here From 230e0a3cd5dcea7204eecd505728d66cca461063 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 16:20:04 -0700 Subject: [PATCH 3/9] fixed lint --- google/cloud/firestore_v1/base_aggregation.py | 8 ++++++-- google/cloud/firestore_v1/base_collection.py | 6 ++++-- google/cloud/firestore_v1/watch.py | 4 +++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 98fa95153..41efbcbac 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -89,7 +89,9 @@ def _to_protobuf(self): class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): # convert field path to string if needed - field_str = field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + field_str = ( + field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + ) self.field_ref: str = field_str super(SumAggregation, self).__init__(alias=alias) @@ -106,7 +108,9 @@ def _to_protobuf(self): class AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): # convert field path to string if needed - field_str = field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + field_str = ( + field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + ) self.field_ref: str = field_str super(AvgAggregation, self).__init__(alias=alias) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 416999c76..6094fad24 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -43,7 +43,8 @@ BaseVectorQuery, DistanceMeasure, ) - from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList @@ -225,7 +226,8 @@ def list_documents( retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Union[ - Generator[BaseDocumentReference, Any, Any], AsyncGenerator[BaseDocumentReference, Any] + Generator[DocumentReference, Any, Any], + AsyncGenerator[AsyncDocumentReference, Any], ]: raise NotImplementedError diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 10c37db02..971485655 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -244,7 +244,9 @@ def _init_stream(self): self._rpc.add_done_callback(self._on_rpc_done) # The server assigns and updates the resume token. - self._consumer: BackgroundConsumer | None = BackgroundConsumer(self._rpc, self.on_snapshot) + self._consumer: BackgroundConsumer | None = BackgroundConsumer( + self._rpc, self.on_snapshot + ) self._consumer.start() @classmethod From b293956a9a137ca50a6875914184d1813e550558 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 16:46:13 -0700 Subject: [PATCH 4/9] updated reference --- google/cloud/firestore_v1/async_collection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 75db41c5e..82f1ba3d0 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -22,7 +22,6 @@ from google.cloud.firestore_v1 import ( async_aggregation, - async_document, async_query, async_vector_query, transaction, @@ -142,7 +141,7 @@ async def add( def document( self, document_id: str | None = None - ) -> async_document.AsyncDocumentReference: + ) -> AsyncDocumentReference: """Create a sub-document underneath the current collection. Args: From a88cccbcceb6a72a21f9209ba32c19d401f1dc8b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 16:56:02 -0700 Subject: [PATCH 5/9] fixed typing for FieldPath and FieldFilter --- google/cloud/firestore_v1/async_collection.py | 4 +--- google/cloud/firestore_v1/base_query.py | 10 +++++----- google/cloud/firestore_v1/field_path.py | 14 +++++++------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 82f1ba3d0..323c802fa 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -139,9 +139,7 @@ async def add( write_result = await document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref - def document( - self, document_id: str | None = None - ) -> AsyncDocumentReference: + def document(self, document_id: str | None = None) -> AsyncDocumentReference: """Create a sub-document underneath the current collection. Args: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 5a9efaf78..2c81b47bd 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -178,7 +178,7 @@ def _validate_opation(op_string, value): class FieldFilter(BaseFilter): """Class representation of a Field Filter.""" - def __init__(self, field_path, op_string, value=None): + def __init__(self, field_path: str, op_string: str, value: Any | None = None): self.field_path = field_path self.value = value self.op_string = _validate_opation(op_string, value) @@ -204,8 +204,8 @@ class BaseCompositeFilter(BaseFilter): def __init__( self, - operator=StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, - filters=None, + operator: int = StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, + filters: list[BaseFilter] | None = None, ): self.operator = operator if filters is None: @@ -237,7 +237,7 @@ def _to_pb(self): class Or(BaseCompositeFilter): """Class representation of an OR Filter.""" - def __init__(self, filters): + def __init__(self, filters: list[BaseFilter]): super().__init__( operator=StructuredQuery.CompositeFilter.Operator.OR, filters=filters ) @@ -246,7 +246,7 @@ def __init__(self, filters): class And(BaseCompositeFilter): """Class representation of an AND Filter.""" - def __init__(self, filters): + def __init__(self, filters: list[BaseFilter]): super().__init__( operator=StructuredQuery.CompositeFilter.Operator.AND, filters=filters ) diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 048eb64d0..27ac6cc45 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -263,7 +263,7 @@ class FieldPath(object): Indicating path of the key to be used. """ - def __init__(self, *parts): + def __init__(self, *parts: str): for part in parts: if not isinstance(part, str) or not part: error = "One or more components is not a string or is empty." @@ -271,7 +271,7 @@ def __init__(self, *parts): self.parts = tuple(parts) @classmethod - def from_api_repr(cls, api_repr: str): + def from_api_repr(cls, api_repr: str) -> "FieldPath": """Factory: create a FieldPath from the string formatted per the API. Args: @@ -288,7 +288,7 @@ def from_api_repr(cls, api_repr: str): return cls(*parse_field_path(api_repr)) @classmethod - def from_string(cls, path_string: str): + def from_string(cls, path_string: str) -> "FieldPath": """Factory: create a FieldPath from a unicode string representation. This method splits on the character `.` and disallows the @@ -351,7 +351,7 @@ def __add__(self, other): else: return NotImplemented - def to_api_repr(self): + def to_api_repr(self) -> str: """Render a quoted string representation of the FieldPath Returns: @@ -360,7 +360,7 @@ def to_api_repr(self): """ return render_field_path(self.parts) - def eq_or_parent(self, other): + def eq_or_parent(self, other) -> bool: """Check whether ``other`` is an ancestor. Returns: @@ -369,7 +369,7 @@ def eq_or_parent(self, other): """ return self.parts[: len(other.parts)] == other.parts[: len(self.parts)] - def lineage(self): + def lineage(self) -> set["FieldPath"]: """Return field paths for all parents. Returns: Set[:class:`FieldPath`] @@ -378,7 +378,7 @@ def lineage(self): return {FieldPath(*self.parts[:index]) for index in indexes} @staticmethod - def document_id(): + def document_id() -> str: """A special FieldPath value to refer to the ID of a document. It can be used in queries to sort or filter by the document ID. From 78e2266628c9bbad618c15f23046b1d59be780b5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 17:09:32 -0700 Subject: [PATCH 6/9] fixed cover --- tests/unit/v1/test_aggregation.py | 20 ++++++++++++++++++++ tests/unit/v1/test_watch.py | 11 +++++++++++ 2 files changed, 31 insertions(+) diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 4d1eed198..7eaeabb38 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -49,6 +49,12 @@ def test_count_aggregation_to_pb(): assert count_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_count_aggregation_no_alias_to_pb(): + count_aggregation = CountAggregation(alias=None) + got_pb = count_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_sum_aggregation_w_field_path(): """ SumAggregation should convert FieldPath inputs into strings @@ -86,6 +92,12 @@ def test_sum_aggregation_to_pb(): assert sum_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_sum_aggregation_no_alias_to_pb(): + sum_aggregation = SumAggregation("someref", alias=None) + got_pb = sum_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_avg_aggregation_to_pb(): from google.cloud.firestore_v1.types import query as query_pb2 @@ -101,6 +113,14 @@ def test_avg_aggregation_to_pb(): assert avg_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_avg_aggregation_no_alias_to_pb(): + from google.cloud.firestore_v1.types import query as query_pb2 + + avg_aggregation = AvgAggregation("someref", alias=None) + got_pb = avg_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_aggregation_query_constructor(): client = make_client() parent = client.collection("dee") diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 6d8c12abc..e125f45f1 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -322,6 +322,17 @@ def test_watch_close(): assert inst._closed +def test_watch_double_close(): + """ + Calling close twice should succeed with no error + """ + inst = _make_watch() + inst.close() + inst.close() + assert inst._consumer is None + assert inst._rpc is None + + def test_watch__get_rpc_request_wo_resume_token(): inst = _make_watch() From edb8b17b0eb29e35a91c5e77f64be67f6fa9a644 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 May 2025 17:27:24 -0700 Subject: [PATCH 7/9] fix cover --- tests/unit/v1/test_aggregation.py | 2 -- tests/unit/v1/test_watch.py | 8 +++----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 7eaeabb38..1764e6a4f 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -114,8 +114,6 @@ def test_avg_aggregation_to_pb(): def test_avg_aggregation_no_alias_to_pb(): - from google.cloud.firestore_v1.types import query as query_pb2 - avg_aggregation = AvgAggregation("someref", alias=None) got_pb = avg_aggregation._to_protobuf() assert got_pb.alias == "" diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index e125f45f1..63e2233a4 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -322,12 +322,10 @@ def test_watch_close(): assert inst._closed -def test_watch_double_close(): - """ - Calling close twice should succeed with no error - """ +def test_watch_close_w_empty_attrs(): inst = _make_watch() - inst.close() + inst._consumer = None + inst._rpc = None inst.close() assert inst._consumer is None assert inst._rpc is None From 0af6a11a9b6c939c154b183f91339ff4a8f88ed6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 23 May 2025 15:43:44 -0700 Subject: [PATCH 8/9] fixed vector search type --- google/cloud/firestore_v1/async_query.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index d4fd45fa4..246d62e8a 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,7 +20,16 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + List, + Optional, + Type, + Union, + Sequence, +) from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -248,7 +257,7 @@ async def get( def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, @@ -261,7 +270,7 @@ def find_nearest( Args: vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector (Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. From 56466b42264da4678b9dc22ad514eef77b84b757 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Jun 2025 14:19:57 -0700 Subject: [PATCH 9/9] fixed mypy --- google/cloud/firestore_v1/base_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index acbd148fb..4a0e3f6b8 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -476,7 +476,7 @@ def _prep_collections( read_time: datetime.datetime | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" - request = { + request: dict[str, Any] = { "parent": "{}/documents".format(self._database_string), } if read_time is not None: