diff --git a/dev/conformance/runner.ts b/dev/conformance/runner.ts index f333fd933..fde0a6dcf 100644 --- a/dev/conformance/runner.ts +++ b/dev/conformance/runner.ts @@ -272,7 +272,10 @@ function commitHandler( updateTime: {}, }); } - return response(res); + const promise = response(res) as any; + // Add cancel method to make it a CancellablePromise + promise.cancel = () => {}; + return promise; }; } @@ -284,7 +287,9 @@ function queryHandler(spec: ConformanceProto) { ); const expectedQuery = STRUCTURED_QUERY_TYPE.fromObject(spec.query); expect(actualQuery).to.deep.equal(expectedQuery); - const stream = through2.obj(); + const stream = through2.obj() as any; + // Add cancel method to make it a CancellableStream + stream.cancel = () => {}; setImmediate(() => { // Empty query always emits a readTime stream.push({readTime: {seconds: 0, nanos: 0}}); @@ -299,7 +304,9 @@ function getHandler(spec: ConformanceProto) { return (request: api.IBatchGetDocumentsRequest) => { const getDocument = spec.request; expect(request.documents![0]).to.equal(getDocument.name); - const stream = through2.obj(); + const stream = through2.obj() as any; + // Add cancel method to make it a CancellableStream + stream.cancel = () => {}; setImmediate(() => { stream.push({ missing: getDocument.name, @@ -432,7 +439,12 @@ function runTest(spec: ConformanceProto) { const expectedSnapshots = spec.snapshots; const writeStream = through2.obj(); const overrides: ApiOverride = { - listen: () => duplexify.obj(through2.obj(), writeStream), + listen: () => { + const stream = duplexify.obj(through2.obj(), writeStream) as any; + // Add cancel method to make it a CancellableStream + stream.cancel = () => {}; + return stream; + }, }; return createInstance(overrides).then(() => { diff --git a/dev/src/abort-util.ts b/dev/src/abort-util.ts new file mode 100644 index 000000000..a47768e12 --- /dev/null +++ b/dev/src/abort-util.ts @@ -0,0 +1,79 @@ +/*! + * Copyright 2024 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Interface for objects that can be cancelled. + */ +export interface Cancellable { + cancel(): void; +} + +/** + * Utility class for working with AbortSignal and cancellable operations. + */ +export class AbortUtil { + /** + * Throws an error if the AbortSignal is already aborted. + */ + static throwIfAborted(signal: AbortSignal | null): void { + if (signal?.aborted) { + throw new Error('The operation was aborted'); + } + } + + /** + * Creates a Promise that rejects when the AbortSignal is aborted. + */ + static createAbortPromise( + signal: AbortSignal, + cancellable: Cancellable + ): Promise { + return new Promise((_, reject) => { + const onAbort = () => { + cancellable.cancel(); + reject(new Error('The operation was aborted')); + }; + + if (signal.aborted) { + onAbort(); + } else { + signal.addEventListener('abort', onAbort, { once: true }); + } + }); + } + + /** + * Makes a Promise cancellable with an AbortSignal. + */ + static async makeCancellable( + promise: Promise, + cancellable: Cancellable, + signal: AbortSignal | null + ): Promise { + if (!signal) { + return promise; + } + + // Check if already aborted + AbortUtil.throwIfAborted(signal); + + // Race the original promise against the abort promise + return Promise.race([ + promise, + AbortUtil.createAbortPromise(signal, cancellable) + ]); + } +} \ No newline at end of file diff --git a/dev/src/bulk-writer.ts b/dev/src/bulk-writer.ts index 1f8371624..1ada2f8ed 100644 --- a/dev/src/bulk-writer.ts +++ b/dev/src/bulk-writer.ts @@ -246,7 +246,7 @@ class BulkCommitBatch extends WriteBatch { return this.docPaths.has(documentRef.path); } - async bulkCommit(options: {requestTag?: string} = {}): Promise { + async bulkCommit(options: {requestTag?: string, abortSignal?: AbortSignal} = {}): Promise { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_BULK_WRITER_COMMIT, async () => { @@ -266,7 +266,7 @@ class BulkCommitBatch extends WriteBatch { response = await this._commit< api.BatchWriteRequest, api.BatchWriteResponse - >({retryCodes, methodName: 'batchWrite', requestTag: tag}); + >(options, {retryCodes, methodName: 'batchWrite', requestTag: tag}); //TODO: add public options } catch (err) { // Map the failure to each individual write's result. const ops = Array.from({length: this.pendingOps.length}); diff --git a/dev/src/collection-group.ts b/dev/src/collection-group.ts index 92ca05879..f15c7f2c5 100644 --- a/dev/src/collection-group.ts +++ b/dev/src/collection-group.ts @@ -80,7 +80,8 @@ export class CollectionGroup< * `QueryPartition`s. */ async *getPartitions( - desiredPartitionCount: number + desiredPartitionCount: number, + options?: firestore.FirestoreRequestOptions ): AsyncIterable> { const partitions: Array[] = []; @@ -108,7 +109,8 @@ export class CollectionGroup< 'partitionQueryStream', /* bidirectional= */ false, request, - tag + tag, + options ); stream.resume(); diff --git a/dev/src/document-reader.ts b/dev/src/document-reader.ts index f1c238a63..a695587fc 100644 --- a/dev/src/document-reader.ts +++ b/dev/src/document-reader.ts @@ -22,7 +22,7 @@ import {google} from '../protos/firestore_v1_proto_api'; import {logger} from './logger'; import {Firestore} from './index'; import {Timestamp} from './timestamp'; -import {DocumentData} from '@google-cloud/firestore'; +import {DocumentData, FirestoreRequestOptions} from '@google-cloud/firestore'; import api = google.firestore.v1; interface BatchGetResponse { @@ -79,9 +79,10 @@ export class DocumentReader { * @param requestTag A unique client-assigned identifier for this request. */ async get( - requestTag: string + requestTag: string, + options?: FirestoreRequestOptions ): Promise>> { - const {result} = await this._get(requestTag); + const {result} = await this._get(requestTag, options); //TODO: add public options return result; } @@ -92,9 +93,10 @@ export class DocumentReader { * @param requestTag A unique client-assigned identifier for this request. */ async _get( - requestTag: string + requestTag: string, + options: FirestoreRequestOptions | undefined ): Promise> { - await this.fetchDocuments(requestTag); + await this.fetchDocuments(requestTag, options); // BatchGetDocuments doesn't preserve document order. We use the request // order to sort the resulting documents. @@ -125,7 +127,7 @@ export class DocumentReader { }; } - private async fetchDocuments(requestTag: string): Promise { + private async fetchDocuments(requestTag: string, options: FirestoreRequestOptions | undefined): Promise { if (!this.outstandingDocuments.size) { return; } @@ -156,7 +158,8 @@ export class DocumentReader { 'batchGetDocuments', /* bidirectional= */ false, request, - requestTag + requestTag, + options ); stream.resume(); @@ -217,7 +220,7 @@ export class DocumentReader { shouldRetry ); if (shouldRetry) { - return this.fetchDocuments(requestTag); + return this.fetchDocuments(requestTag, options); } else { throw error; } diff --git a/dev/src/index.ts b/dev/src/index.ts index 4f093022c..b99e53172 100644 --- a/dev/src/index.ts +++ b/dev/src/index.ts @@ -74,6 +74,7 @@ import { validateTimestamp, } from './validate'; import {WriteBatch} from './write-batch'; +import {AbortUtil, Cancellable} from './abort-util'; import {interfaces} from './v1/firestore_client_config.json'; const serviceConfig = interfaces['google.firestore.v1.Firestore']; @@ -1288,9 +1289,9 @@ export class Firestore implements firestore.Firestore { * }); * ``` */ - listCollections(): Promise { + listCollections(options?: firestore.FirestoreRequestOptions): Promise { const rootDocument = new DocumentReference(this, ResourcePath.EMPTY); - return rootDocument.listCollections(); + return rootDocument.listCollections(options); } /** @@ -1331,9 +1332,10 @@ export class Firestore implements firestore.Firestore { 1 ); - const {documents, fieldMask} = parseGetAllArguments( + const { parsedGetAllArguments, readOptions } = parseGetAllArguments( documentRefsOrReadOptions ); + const { documents, fieldMask } = parsedGetAllArguments; this._traceUtil.currentSpan().setAttributes({ [ATTRIBUTE_KEY_IS_TRANSACTIONAL]: false, @@ -1348,7 +1350,7 @@ export class Firestore implements firestore.Firestore { return this.initializeIfNeeded(tag) .then(() => { const reader = new DocumentReader(this, documents, fieldMask); - return reader.get(tag); + return reader.get(tag, readOptions); }) .catch(err => { throw wrapError(err, stack); @@ -1793,11 +1795,17 @@ export class Firestore implements firestore.Firestore { methodName: FirestoreUnaryMethod, request: Req, requestTag: string, + options: firestore.FirestoreRequestOptions | undefined, retryCodes?: number[] ): Promise { + const abortSignal = options?.abortSignal || null; + + // Check if already aborted before starting + AbortUtil.throwIfAborted(abortSignal); + const callOptions = this.createCallOptions(methodName, retryCodes); - return this._clientPool.run( + const requestPromise = this._clientPool.run( requestTag, /* requiresGrpc= */ false, async gapicClient => { @@ -1808,9 +1816,21 @@ export class Firestore implements firestore.Firestore { 'Sending request: %j', request ); - const [result] = await ( - gapicClient[methodName] as UnaryMethod - )(request, callOptions); + + // Make the GAX call - this returns a CancellablePromise + const gaxCall = (gapicClient[methodName] as UnaryMethod)(request, callOptions); + + // Create cancellable wrapper for the GAX call + const cancellable: Cancellable = { + cancel: () => { + logger('Firestore.request', requestTag, 'Cancelling request due to abort signal'); + gaxCall.cancel(); + } + }; + + // Make the call cancellable with AbortSignal + const [result] = await AbortUtil.makeCancellable(gaxCall, cancellable, abortSignal); + logger( 'Firestore.request', requestTag, @@ -1824,6 +1844,8 @@ export class Firestore implements firestore.Firestore { } } ); + + return requestPromise; } /** @@ -1847,8 +1869,14 @@ export class Firestore implements firestore.Firestore { methodName: FirestoreStreamingMethod, bidrectional: boolean, request: {}, - requestTag: string + requestTag: string, + options: firestore.FirestoreRequestOptions | undefined ): Promise { + const abortSignal = options?.abortSignal || null; + + // Check if already aborted before starting + AbortUtil.throwIfAborted(abortSignal); + const callOptions = this.createCallOptions(methodName); const bidirectional = methodName === 'listen'; @@ -1874,6 +1902,15 @@ export class Firestore implements firestore.Firestore { const stream = bidirectional ? gapicClient[methodName](callOptions) : gapicClient[methodName](request, callOptions); + + // Create a cancellable object for the stream + const cancellable: Cancellable = { + cancel: () => { + logger('Firestore.requestStream', requestTag, 'Cancelling stream due to abort signal'); + (stream as any).cancel(); + } + }; + const logStream = new Transform({ objectMode: true, transform: (chunk, encoding, callback) => { @@ -1901,12 +1938,20 @@ export class Firestore implements firestore.Firestore { stream.pipe(logStream); const lifetime = new Deferred(); - const resultStream = await this._initializeStream( + const streamPromise = this._initializeStream( stream, lifetime, requestTag, bidirectional ? request : undefined ); + + // Make the stream initialization cancellable + const resultStream = await AbortUtil.makeCancellable( + streamPromise, + cancellable, + abortSignal + ); + resultStream.on('end', () => { stream.end(); this._traceUtil diff --git a/dev/src/reference/aggregate-query.ts b/dev/src/reference/aggregate-query.ts index 5a78f4d5e..8ca212555 100644 --- a/dev/src/reference/aggregate-query.ts +++ b/dev/src/reference/aggregate-query.ts @@ -34,6 +34,7 @@ import { SPAN_NAME_AGGREGATION_QUERY_GET, SPAN_NAME_RUN_AGGREGATION_QUERY, } from '../telemetry/trace-util'; +import { request } from 'http'; /** * A query that calculates aggregations over an underlying query. @@ -82,13 +83,13 @@ export class AggregateQuery< * * @return A promise that will be resolved with the results of the query. */ - async get(): Promise< + async get(options?: firestore.FirestoreRequestOptions): Promise< AggregateQuerySnapshot > { return this._query._firestore._traceUtil.startActiveSpan( SPAN_NAME_AGGREGATION_QUERY_GET, async () => { - const {result} = await this._get(); + const {result} = await this._get(options); return result; } ); @@ -104,13 +105,14 @@ export class AggregateQuery< * transaction, or timestamp to use as read time. */ async _get( + options: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions ): Promise< QuerySnapshotResponse< AggregateQuerySnapshot > > { - const response = await this._getResponse(transactionOrReadTime); + const response = await this._getResponse(options, transactionOrReadTime); if (!response.result) { throw new Error('No AggregateQuery results'); } @@ -129,6 +131,7 @@ export class AggregateQuery< * transaction, or timestamp to use as read time. */ _getResponse( + options: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, explainOptions?: firestore.ExplainOptions ): Promise< @@ -144,7 +147,7 @@ export class AggregateQuery< AggregateQuerySnapshot > = {}; - const stream = this._stream(transactionOrReadTime, explainOptions); + const stream = this._stream(options, transactionOrReadTime, explainOptions); stream.on('error', err => { reject(wrapError(err, stack)); }); @@ -187,6 +190,7 @@ export class AggregateQuery< * @returns A stream of document results optionally preceded by a transaction response. */ _stream( + options: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, explainOptions?: firestore.ExplainOptions ): Readable { @@ -234,7 +238,8 @@ export class AggregateQuery< 'runAggregationQuery', /* bidirectional= */ false, request, - tag + tag, + options ); stream.on('close', () => { backendStream.resume(); @@ -387,13 +392,15 @@ export class AggregateQuery< * statistics from the query execution (if any), and the query results (if any). */ async explain( - options?: firestore.ExplainOptions + options?: firestore.ExplainOptions, + requestOptions?: firestore.FirestoreRequestOptions ): Promise< ExplainResults< AggregateQuerySnapshot > > { const {result, explainMetrics} = await this._getResponse( + requestOptions, undefined, options || {} ); diff --git a/dev/src/reference/collection-reference.ts b/dev/src/reference/collection-reference.ts index 3f62851b1..7ac12e0a1 100644 --- a/dev/src/reference/collection-reference.ts +++ b/dev/src/reference/collection-reference.ts @@ -165,7 +165,7 @@ export class CollectionReference< * }); * ``` */ - listDocuments(): Promise< + listDocuments(options?: firestore.FirestoreRequestOptions): Promise< Array> > { return this._firestore._traceUtil.startActiveSpan( @@ -190,7 +190,7 @@ export class CollectionReference< .request< api.IListDocumentsRequest, api.IDocument[] - >('listDocuments', request, tag) + >('listDocuments', request, tag, options) .then(documents => { // Note that the backend already orders these documents by name, // so we do not need to manually sort them. diff --git a/dev/src/reference/document-reference.ts b/dev/src/reference/document-reference.ts index 8a5a5a49d..567a19653 100644 --- a/dev/src/reference/document-reference.ts +++ b/dev/src/reference/document-reference.ts @@ -205,11 +205,14 @@ export class DocumentReference< * }); * ``` */ - get(): Promise> { + get(options?: firestore.ReadOptions): Promise> { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_DOC_REF_GET, () => { - return this._firestore.getAll(this).then(([result]) => result); + if (options) + return this._firestore.getAll(this, options).then(([result]) => result); + else + return this._firestore.getAll(this).then(([result]) => result); } ); } @@ -259,7 +262,7 @@ export class DocumentReference< * }); * ``` */ - listCollections(): Promise> { + listCollections(options?: firestore.FirestoreRequestOptions): Promise> { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_DOC_REF_LIST_COLLECTIONS, () => { @@ -272,7 +275,7 @@ export class DocumentReference< .request< api.IListCollectionIdsRequest, string[] - >('listCollectionIds', request, tag) + >('listCollectionIds', request, tag, options) .then(collectionIds => { const collections: Array = []; diff --git a/dev/src/reference/query-util.ts b/dev/src/reference/query-util.ts index 0ac50f775..3460f90cc 100644 --- a/dev/src/reference/query-util.ts +++ b/dev/src/reference/query-util.ts @@ -64,6 +64,7 @@ export class QueryUtil< _getResponse( query: Template, + options: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, retryWithCursor = true, explainOptions?: firestore.ExplainOptions @@ -79,9 +80,10 @@ export class QueryUtil< this._stream( query, + options, transactionOrReadTime, retryWithCursor, - explainOptions + explainOptions, ) .on('error', err => { reject(wrapError(err, stack)); @@ -150,7 +152,7 @@ export class QueryUtil< return Date.now() - startTime >= totalTimeout; } - stream(query: Template): NodeJS.ReadableStream { + stream(query: Template, options?: firestore.FirestoreRequestOptions): NodeJS.ReadableStream { if (this._queryOptions.limitType === LimitType.Last) { throw new Error( 'Query results for queries that include limitToLast() ' + @@ -158,7 +160,7 @@ export class QueryUtil< ); } - const responseStream = this._stream(query); + const responseStream = this._stream(query, options); const transform = new Transform({ objectMode: true, transform(chunk, encoding, callback) { @@ -173,6 +175,7 @@ export class QueryUtil< _stream( query: Template, + requestOptions: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, retryWithCursor = true, explainOptions?: firestore.ExplainOptions @@ -283,7 +286,8 @@ export class QueryUtil< methodName, /* bidirectional= */ false, request, - tag + tag, + requestOptions ); backendStream.on('error', err => { backendStream.unpipe(stream); diff --git a/dev/src/reference/query.ts b/dev/src/reference/query.ts index f8d407190..2d61033e4 100644 --- a/dev/src/reference/query.ts +++ b/dev/src/reference/query.ts @@ -1134,11 +1134,11 @@ export class Query< * }); * ``` */ - async get(): Promise> { + async get(options?: firestore.FirestoreRequestOptions): Promise> { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_QUERY_GET, async () => { - const {result} = await this._get(); + const {result} = await this._get(options); return result; } ); @@ -1153,12 +1153,14 @@ export class Query< * from the query execution (if any), and the query results (if any). */ async explain( - options?: firestore.ExplainOptions + options?: firestore.ExplainOptions, + requestOptions?: firestore.FirestoreRequestOptions ): Promise>> { if (options === undefined) { options = {}; } const {result, explainMetrics} = await this._getResponse( + requestOptions, undefined, options ); @@ -1178,9 +1180,10 @@ export class Query< * transaction, or timestamp to use as read time. */ async _get( - transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions + options: firestore.FirestoreRequestOptions | undefined, + transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, ): Promise>> { - const result = await this._getResponse(transactionOrReadTime); + const result = await this._getResponse(options, transactionOrReadTime, undefined); if (!result.result) { throw new Error('No QuerySnapshot result'); } @@ -1190,11 +1193,13 @@ export class Query< } _getResponse( + options: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, - explainOptions?: firestore.ExplainOptions + explainOptions?: firestore.ExplainOptions, ): Promise>> { return this._queryUtil._getResponse( this, + options, transactionOrReadTime, true, explainOptions @@ -1222,8 +1227,8 @@ export class Query< * }); * ``` */ - stream(): NodeJS.ReadableStream { - return this._queryUtil.stream(this); + stream(options?: firestore.FirestoreRequestOptions): NodeJS.ReadableStream { + return this._queryUtil.stream(this, options); } /** @@ -1255,7 +1260,8 @@ export class Query< * ``` */ explainStream( - explainOptions?: firestore.ExplainOptions + explainOptions?: firestore.ExplainOptions, + requestOptions?: firestore.FirestoreRequestOptions ): NodeJS.ReadableStream { if (explainOptions === undefined) { explainOptions = {}; @@ -1267,7 +1273,7 @@ export class Query< ); } - const responseStream = this._stream(undefined, explainOptions); + const responseStream = this._stream(requestOptions, undefined, explainOptions); const transform = new Transform({ objectMode: true, transform( @@ -1481,11 +1487,13 @@ export class Query< * @returns A stream of document results, optionally preceded by a transaction response. */ _stream( + options: firestore.FirestoreRequestOptions | undefined, transactionOrReadTime?: Uint8Array | Timestamp | api.ITransactionOptions, explainOptions?: firestore.ExplainOptions ): NodeJS.ReadableStream { return this._queryUtil._stream( this, + options, transactionOrReadTime, true, explainOptions @@ -1517,7 +1525,7 @@ export class Query< * unsubscribe(); * ``` */ - onSnapshot( + onSnapshot( onNext: (snapshot: QuerySnapshot) => void, onError?: (error: Error) => void ): () => void { diff --git a/dev/src/reference/vector-query.ts b/dev/src/reference/vector-query.ts index 3fe36194f..58a176441 100644 --- a/dev/src/reference/vector-query.ts +++ b/dev/src/reference/vector-query.ts @@ -113,12 +113,13 @@ export class VectorQuery< * from the query execution (if any), and the query results (if any). */ async explain( - options?: firestore.ExplainOptions + options?: firestore.ExplainOptions, + requestOptions?: firestore.FirestoreRequestOptions ): Promise>> { if (options === undefined) { options = {}; } - const {result, explainMetrics} = await this._getResponse(options); + const {result, explainMetrics} = await this._getResponse(requestOptions, options); if (!explainMetrics) { throw new Error('No explain results'); } @@ -130,8 +131,8 @@ export class VectorQuery< * * @returns A promise that will be resolved with the results of the query. */ - async get(): Promise> { - const {result} = await this._getResponse(); + async get(options?: firestore.FirestoreRequestOptions): Promise> { + const {result} = await this._getResponse(options); if (!result) { throw new Error('No VectorQuerySnapshot result'); } @@ -139,10 +140,12 @@ export class VectorQuery< } _getResponse( - explainOptions?: firestore.ExplainOptions + options: firestore.FirestoreRequestOptions | undefined, + explainOptions?: firestore.ExplainOptions, ): Promise>> { return this._queryUtil._getResponse( this, + options, /*transactionOrReadTime*/ undefined, // VectorQuery cannot be retried with cursors as they do not support cursors yet. /*retryWithCursor*/ false, @@ -158,9 +161,10 @@ export class VectorQuery< * @internal * @returns A stream of document results. */ - _stream(transactionId?: Uint8Array): NodeJS.ReadableStream { + _stream(options: firestore.FirestoreRequestOptions | undefined, transactionId?: Uint8Array): NodeJS.ReadableStream { return this._queryUtil._stream( this, + options, transactionId, /*retryWithCursor*/ false ); diff --git a/dev/src/transaction.ts b/dev/src/transaction.ts index ed3a8d477..24a3c72a6 100644 --- a/dev/src/transaction.ts +++ b/dev/src/transaction.ts @@ -135,7 +135,8 @@ export class Transaction implements firestore.Transaction { * @return {Promise} A QuerySnapshot for the retrieved data. */ get( - query: firestore.Query + query: firestore.Query, + options?: firestore.FirestoreRequestOptions ): Promise>; /** @@ -146,7 +147,8 @@ export class Transaction implements firestore.Transaction { * @return {Promise} A DocumentSnapshot for the read data. */ get( - documentRef: firestore.DocumentReference + documentRef: firestore.DocumentReference, + options?: firestore.FirestoreRequestOptions ): Promise>; /** @@ -165,7 +167,8 @@ export class Transaction implements firestore.Transaction { AggregateSpecType, AppModelType, DbModelType - > + >, + options?: firestore.FirestoreRequestOptions ): Promise< AggregateQuerySnapshot >; @@ -201,7 +204,8 @@ export class Transaction implements firestore.Transaction { refOrQuery: | firestore.DocumentReference | firestore.Query - | firestore.AggregateQuery + | firestore.AggregateQuery, + options?: firestore.FirestoreRequestOptions ): Promise< | DocumentSnapshot | QuerySnapshot @@ -215,7 +219,7 @@ export class Transaction implements firestore.Transaction { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_TRANSACTION_GET_DOCUMENT, () => { - return this.withLazyStartedTransaction(refOrQuery, this.getSingleFn); + return this.withLazyStartedTransaction(refOrQuery, this.getSingleFn, options); } ); } @@ -226,7 +230,7 @@ export class Transaction implements firestore.Transaction { ? SPAN_NAME_TRANSACTION_GET_QUERY : SPAN_NAME_TRANSACTION_GET_AGGREGATION_QUERY, () => { - return this.withLazyStartedTransaction(refOrQuery, this.getQueryFn); + return this.withLazyStartedTransaction(refOrQuery, this.getQueryFn, options); } ); } @@ -280,9 +284,12 @@ export class Transaction implements firestore.Transaction { 1 ); + const parsed = parseGetAllArguments(documentRefsOrReadOptions); + return this.withLazyStartedTransaction( - parseGetAllArguments(documentRefsOrReadOptions), - this.getBatchFn + parsed.parsedGetAllArguments, + this.getBatchFn, + parsed.readOptions ); } @@ -485,7 +492,7 @@ export class Transaction implements firestore.Transaction { * @private * @internal */ - async commit(): Promise { + async commit(options?: firestore.FirestoreRequestOptions): Promise { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_TRANSACTION_COMMIT, async () => { @@ -504,10 +511,10 @@ export class Transaction implements firestore.Transaction { return; } - await this._writeBatch._commit({ + await this._writeBatch._commit(options, { transactionId, requestTag: this._requestTag, - }); + }); //TODO: add public options this._transactionIdPromise = undefined; this._prevTransactionId = transactionId; }, @@ -526,7 +533,7 @@ export class Transaction implements firestore.Transaction { * @private * @internal */ - async rollback(): Promise { + async rollback(options?: firestore.FirestoreRequestOptions | undefined): Promise { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_TRANSACTION_ROLLBACK, async () => { @@ -558,7 +565,7 @@ export class Transaction implements firestore.Transaction { // Rollback can be done concurrently thereby reducing latency caused by // otherwise blocking. this._firestore - .request('rollback', request, this._requestTag) + .request('rollback', request, this._requestTag, options) .catch(err => { logger( 'Firestore.runTransaction', @@ -648,7 +655,8 @@ export class Transaction implements firestore.Transaction { * context. */ async runTransactionOnce( - updateFunction: (transaction: Transaction) => Promise + updateFunction: (transaction: Transaction) => Promise, + options?: firestore.FirestoreRequestOptions ): Promise { try { const promise = updateFunction(this); @@ -659,7 +667,7 @@ export class Transaction implements firestore.Transaction { } const result = await promise; if (this._writeBatch) { - await this.commit(); + await this.commit(options); } return result; } catch (err) { @@ -684,22 +692,24 @@ export class Transaction implements firestore.Transaction { resultFn: ( this: typeof this, param: TParam, - opts: Uint8Array | api.ITransactionOptions | Timestamp - ) => Promise<{transaction?: Uint8Array; result: TResult}> + opts: Uint8Array | api.ITransactionOptions | Timestamp, + requestOptions: firestore.FirestoreRequestOptions | undefined + ) => Promise<{transaction?: Uint8Array; result: TResult}>, + options: firestore.FirestoreRequestOptions | undefined ): Promise { if (this._transactionIdPromise) { // Simply queue this subsequent read operation after the first read // operation has resolved and we don't expect a transaction ID in the // response because we are not starting a new transaction return this._transactionIdPromise - .then(opts => resultFn.call(this, param, opts)) + .then(opts => resultFn.call(this, param, opts, options)) .then(r => r.result); } else { if (this._readOnlyReadTime) { // We do not start a transaction for read-only transactions // do not set _prevTransactionId return resultFn - .call(this, param, this._readOnlyReadTime) + .call(this, param, this._readOnlyReadTime, options) .then(r => r.result); } else { // This is the first read of the transaction so we create the appropriate @@ -713,7 +723,7 @@ export class Transaction implements firestore.Transaction { opts.readOnly = {}; } - const resultPromise = resultFn.call(this, param, opts); + const resultPromise = resultFn.call(this, param, opts, options); // Ensure the _transactionIdPromise is set synchronously so that // subsequent operations will not race to start another transaction @@ -737,7 +747,8 @@ export class Transaction implements firestore.Transaction { DbModelType extends firestore.DocumentData, >( document: DocumentReference, - opts: Uint8Array | api.ITransactionOptions | Timestamp + opts: Uint8Array | api.ITransactionOptions | Timestamp, + requestOptions: firestore.FirestoreRequestOptions | undefined ): Promise<{ transaction?: Uint8Array; result: DocumentSnapshot; @@ -751,7 +762,7 @@ export class Transaction implements firestore.Transaction { const { transaction, result: [result], - } = await documentReader._get(this._requestTag); + } = await documentReader._get(this._requestTag, requestOptions); return {transaction, result}; } @@ -766,7 +777,8 @@ export class Transaction implements firestore.Transaction { documents: Array>; fieldMask?: FieldPath[]; }, - opts: Uint8Array | api.ITransactionOptions | Timestamp + opts: Uint8Array | api.ITransactionOptions | Timestamp, + requestOptions: firestore.FirestoreRequestOptions | undefined ): Promise<{ transaction?: Uint8Array; result: DocumentSnapshot[]; @@ -780,7 +792,7 @@ export class Transaction implements firestore.Transaction { fieldMask, opts ); - return documentReader._get(this._requestTag); + return documentReader._get(this._requestTag, requestOptions); } ); } @@ -790,12 +802,13 @@ export class Transaction implements firestore.Transaction { TQuery extends Query | AggregateQuery, >( query: TQuery, - opts: Uint8Array | api.ITransactionOptions | Timestamp + opts: Uint8Array | api.ITransactionOptions | Timestamp, + requestOptions: firestore.FirestoreRequestOptions | undefined ): Promise<{ transaction?: Uint8Array; result: Awaited>['result']; }> { - return query._get(opts); + return query._get(requestOptions, opts); } } @@ -817,8 +830,11 @@ export function parseGetAllArguments< | firestore.ReadOptions > ): { - documents: Array>; - fieldMask: FieldPath[] | undefined; + parsedGetAllArguments: { + documents: Array>; + fieldMask: FieldPath[] | undefined; + } + readOptions: firestore.ReadOptions | undefined; } { let documents: Array>; let readOptions: firestore.ReadOptions | undefined = undefined; @@ -857,7 +873,13 @@ export function parseGetAllArguments< FieldPath.fromArgument(fieldPath) ) : undefined; - return {fieldMask, documents}; + return { + parsedGetAllArguments: { + fieldMask, + documents + }, + readOptions + }; } /** diff --git a/dev/src/types.ts b/dev/src/types.ts index ac7a62d22..ed753cf05 100644 --- a/dev/src/types.ts +++ b/dev/src/types.ts @@ -21,7 +21,7 @@ import { WithFieldValue, } from '@google-cloud/firestore'; -import {CallOptions} from 'google-gax'; +import {CallOptions, CancellablePromise} from 'google-gax'; import {Duplex} from 'stream'; import {google} from '../protos/firestore_v1_proto_api'; @@ -106,7 +106,7 @@ export type FirestoreStreamingMethod = export type UnaryMethod = ( request: Req, callOptions: CallOptions -) => Promise<[Resp, unknown, unknown]>; +) => CancellablePromise<[Resp, unknown, unknown]>; // We don't have type information for the npm package // `functional-red-black-tree`. diff --git a/dev/src/watch.ts b/dev/src/watch.ts index 737272abf..c0c58fa89 100644 --- a/dev/src/watch.ts +++ b/dev/src/watch.ts @@ -296,7 +296,7 @@ abstract class Watch< this.onError = onError; this.docTree = rbtree(this.getComparator()); - this.initStream(); + this.initStream(undefined); const unsubscribe: () => void = () => { logger('Watch.onSnapshot', this.requestTag, 'Unsubscribe called'); @@ -394,7 +394,7 @@ abstract class Watch< * @private * @internal */ - private maybeReopenStream(err: GoogleError): void { + private maybeReopenStream(options: firestore.FirestoreRequestOptions | undefined, err: GoogleError): void { if (this.isActive && !this.isPermanentWatchError(err)) { logger( 'Watch.maybeReopenStream', @@ -408,7 +408,7 @@ abstract class Watch< this.backoff.resetToMax(); } - this.initStream(); + this.initStream(options); } else { this.closeStream(err); } @@ -420,7 +420,7 @@ abstract class Watch< * @private * @internal */ - private resetIdleTimeout(): void { + private resetIdleTimeout(options: firestore.FirestoreRequestOptions | undefined): void { if (this.idleTimeoutHandle) { clearTimeout(this.idleTimeoutHandle); } @@ -436,7 +436,7 @@ abstract class Watch< const error = new GoogleError('Watch stream idle timeout'); error.code = Status.UNKNOWN; - this.maybeReopenStream(error); + this.maybeReopenStream(options, error); }, WATCH_IDLE_TIMEOUT_MS); } @@ -445,13 +445,13 @@ abstract class Watch< * @private * @internal */ - private resetStream(): void { + private resetStream(options: firestore.FirestoreRequestOptions | undefined): void { logger('Watch.resetStream', this.requestTag, 'Restarting stream'); if (this.currentStream) { this.currentStream.end(); this.currentStream = null; } - this.initStream(); + this.initStream(options); } /** @@ -459,7 +459,7 @@ abstract class Watch< * @private * @internal */ - private initStream(): void { + private initStream(options: firestore.FirestoreRequestOptions | undefined): void { this.backoff .backoffAndWait() .then(async () => { @@ -485,7 +485,8 @@ abstract class Watch< 'listen', /* bidirectional= */ true, request, - this.requestTag + this.requestTag, + options ) .then(backendStream => { if (!this.isActive) { @@ -505,16 +506,16 @@ abstract class Watch< } logger('Watch.initStream', this.requestTag, 'Opened new stream'); this.currentStream = backendStream; - this.resetIdleTimeout(); + this.resetIdleTimeout(options); this.currentStream!.on('data', (proto: api.IListenResponse) => { - this.resetIdleTimeout(); - this.onData(proto); + this.resetIdleTimeout(options); + this.onData(options, proto); }) .on('error', err => { if (this.currentStream === backendStream) { this.currentStream = null; - this.maybeReopenStream(err); + this.maybeReopenStream(options, err); } }) .on('end', () => { @@ -523,7 +524,7 @@ abstract class Watch< const err = new GoogleError('Stream ended unexpectedly'); err.code = Status.UNKNOWN; - this.maybeReopenStream(err); + this.maybeReopenStream(options, err); } }); this.currentStream!.resume(); @@ -540,7 +541,7 @@ abstract class Watch< * @private * @internal */ - private onData(proto: api.IListenResponse): void { + private onData(options: firestore.FirestoreRequestOptions | undefined, proto: api.IListenResponse): void { if (proto.targetChange) { logger('Watch.onData', this.requestTag, 'Processing target change'); const change = proto.targetChange; @@ -641,7 +642,7 @@ abstract class Watch< // We need to remove all the current results. this.resetDocs(); // The filter didn't match, so re-issue the query. - this.resetStream(); + this.resetStream(options); } } else { this.closeStream( diff --git a/dev/src/write-batch.ts b/dev/src/write-batch.ts index 62fe7d996..a89de5189 100644 --- a/dev/src/write-batch.ts +++ b/dev/src/write-batch.ts @@ -575,7 +575,7 @@ export class WriteBatch implements firestore.WriteBatch { * }); * ``` */ - commit(): Promise { + commit(options?: firestore.FirestoreRequestOptions): Promise { return this._firestore._traceUtil.startActiveSpan( SPAN_NAME_BATCH_COMMIT, async () => { @@ -585,7 +585,7 @@ export class WriteBatch implements firestore.WriteBatch { // Commits should also be retried when they fail with status code ABORTED. const retryCodes = [StatusCode.ABORTED, ...getRetryCodes('commit')]; - return this._commit({retryCodes}) + return this._commit(options, {retryCodes}) //TODO: add public options .then(response => { return (response.writeResults || []).map( writeResult => @@ -618,12 +618,15 @@ export class WriteBatch implements firestore.WriteBatch { * this request. * @returns A Promise that resolves when this batch completes. */ - async _commit(commitOptions?: { - transactionId?: Uint8Array; - requestTag?: string; - retryCodes?: number[]; - methodName?: FirestoreUnaryMethod; - }): Promise { + async _commit( + requestOptions: firestore.FirestoreRequestOptions | undefined, + commitOptions?: { + transactionId?: Uint8Array; + requestTag?: string; + retryCodes?: number[]; + methodName?: FirestoreUnaryMethod; + } + ): Promise { // Note: We don't call `verifyNotCommitted()` to allow for retries. this._committed = true; @@ -652,6 +655,7 @@ export class WriteBatch implements firestore.WriteBatch { commitOptions?.methodName || 'commit', request as Req, tag, + requestOptions, commitOptions?.retryCodes ); } diff --git a/dev/test/abort-signal-integration.ts b/dev/test/abort-signal-integration.ts new file mode 100644 index 000000000..034dac176 --- /dev/null +++ b/dev/test/abort-signal-integration.ts @@ -0,0 +1,198 @@ +/*! + * Copyright 2024 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {describe, it} from 'mocha'; +import {expect} from 'chai'; +import {Duplex} from 'stream'; + +import {Firestore} from '../src'; +import {createInstance} from './util/helpers'; + +describe('AbortSignal Integration', () => { + let firestore: Firestore; + + beforeEach(() => { + return createInstance().then(firestoreInstance => { + firestore = firestoreInstance; + }); + }); + + afterEach(() => firestore.terminate()); + + describe('Document operations with AbortSignal', () => { + it('should support AbortSignal in get() operation', async () => { + const controller = new AbortController(); + + // Mock the batchGetDocuments to return a stream that can be cancelled + const overrides = { + batchGetDocuments: () => { + const stream = new Duplex({ + objectMode: true, + read() { + // Required _read implementation + }, + write(chunk, encoding, callback) { + callback(); + } + }); + (stream as any).cancel = () => { + stream.destroy(); + }; + + // Simulate a slow response + setTimeout(() => { + stream.push({ + missing: 'projects/test-project/databases/(default)/documents/coll/doc', + readTime: {seconds: 0, nanos: 0}, + }); + stream.push(null); + }, 100); + + return stream; + }, + }; + + return createInstance(overrides).then(async firestoreInstance => { + const testDocRef = firestoreInstance.doc('coll/doc'); + + // Start the get operation with AbortSignal + const getPromise = testDocRef.get({abortSignal: controller.signal}); + + // Abort after a short delay + setTimeout(() => controller.abort(), 10); + + try { + await getPromise; + expect.fail('Should have thrown due to abort'); + } catch (error) { + expect(error.message).to.equal('The operation was aborted'); + } + + await firestoreInstance.terminate(); + }); + }); + }); + + describe('Query operations with AbortSignal', () => { + it('should support AbortSignal in query get() operation', async () => { + const controller = new AbortController(); + + // Mock the runQuery to return a stream that can be cancelled + const overrides = { + runQuery: () => { + const stream = new Duplex({ + objectMode: true, + read() { + // Required _read implementation + }, + write(chunk, encoding, callback) { + callback(); + } + }); + (stream as any).cancel = () => { + stream.destroy(); + }; + + // Simulate a slow response + setTimeout(() => { + stream.push({readTime: {seconds: 0, nanos: 0}}); + stream.push(null); + }, 100); + + return stream; + }, + }; + + return createInstance(overrides).then(async firestoreInstance => { + const query = firestoreInstance.collection('coll').where('field', '==', 'value'); + + // Start the query operation with AbortSignal + const queryPromise = query.get({abortSignal: controller.signal}); + + // Abort after a short delay + setTimeout(() => controller.abort(), 10); + + try { + await queryPromise; + expect.fail('Should have thrown due to abort'); + } catch (error) { + expect(error.message).to.equal('The operation was aborted'); + } + + await firestoreInstance.terminate(); + }); + }); + }); + + describe('Batch operations with AbortSignal', () => { + it('should support AbortSignal in getAll() operation', async () => { + const controller = new AbortController(); + + // Mock the batchGetDocuments to return a stream that can be cancelled + const overrides = { + batchGetDocuments: () => { + const stream = new Duplex({ + objectMode: true, + read() { + // Required _read implementation + }, + write(chunk, encoding, callback) { + callback(); + } + }); + (stream as any).cancel = () => { + stream.destroy(); + }; + + // Simulate a slow response + setTimeout(() => { + stream.push({ + missing: 'projects/test-project/databases/(default)/documents/coll/doc1', + readTime: {seconds: 0, nanos: 0}, + }); + stream.push({ + missing: 'projects/test-project/databases/(default)/documents/coll/doc2', + readTime: {seconds: 0, nanos: 0}, + }); + stream.push(null); + }, 100); + + return stream; + }, + }; + + return createInstance(overrides).then(async firestoreInstance => { + const doc1 = firestoreInstance.doc('coll/doc1'); + const doc2 = firestoreInstance.doc('coll/doc2'); + + // Start the getAll operation with AbortSignal + const getAllPromise = firestoreInstance.getAll(doc1, doc2, {abortSignal: controller.signal}); + + // Abort after a short delay + setTimeout(() => controller.abort(), 10); + + try { + await getAllPromise; + expect.fail('Should have thrown due to abort'); + } catch (error) { + expect(error.message).to.equal('The operation was aborted'); + } + + await firestoreInstance.terminate(); + }); + }); + }); +}); \ No newline at end of file diff --git a/dev/test/abort-signal.ts b/dev/test/abort-signal.ts new file mode 100644 index 000000000..3f0a8c5b5 --- /dev/null +++ b/dev/test/abort-signal.ts @@ -0,0 +1,111 @@ +/*! + * Copyright 2024 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {describe, it} from 'mocha'; +import {expect} from 'chai'; +import {AbortUtil, Cancellable} from '../src/abort-util'; + +describe('AbortUtil', () => { + describe('throwIfAborted', () => { + it('should not throw if signal is null', () => { + expect(() => AbortUtil.throwIfAborted(null)).to.not.throw(); + }); + + it('should not throw if signal is not aborted', () => { + const controller = new AbortController(); + expect(() => AbortUtil.throwIfAborted(controller.signal)).to.not.throw(); + }); + + it('should throw if signal is aborted', () => { + const controller = new AbortController(); + controller.abort(); + expect(() => AbortUtil.throwIfAborted(controller.signal)).to.throw('The operation was aborted'); + }); + }); + + describe('makeCancellable', () => { + it('should return original promise if no signal provided', async () => { + const originalPromise = Promise.resolve('test'); + const cancellable: Cancellable = { cancel: () => {} }; + + const result = await AbortUtil.makeCancellable(originalPromise, cancellable, null); + expect(result).to.equal('test'); + }); + + it('should return original promise if signal is not aborted', async () => { + const controller = new AbortController(); + const originalPromise = Promise.resolve('test'); + const cancellable: Cancellable = { cancel: () => {} }; + + const result = await AbortUtil.makeCancellable(originalPromise, cancellable, controller.signal); + expect(result).to.equal('test'); + }); + + it('should throw if signal is already aborted', async () => { + const controller = new AbortController(); + controller.abort(); + const originalPromise = Promise.resolve('test'); + const cancellable: Cancellable = { cancel: () => {} }; + + try { + await AbortUtil.makeCancellable(originalPromise, cancellable, controller.signal); + expect.fail('Should have thrown'); + } catch (error) { + expect(error.message).to.equal('The operation was aborted'); + } + }); + + it('should cancel and reject when signal is aborted during operation', async () => { + const controller = new AbortController(); + let cancelCalled = false; + const cancellable: Cancellable = { + cancel: () => { cancelCalled = true; } + }; + + // Create a promise that never resolves + const originalPromise = new Promise(() => {}); + + const cancellablePromise = AbortUtil.makeCancellable(originalPromise, cancellable, controller.signal); + + // Abort after a short delay + setTimeout(() => controller.abort(), 10); + + try { + await cancellablePromise; + expect.fail('Should have thrown'); + } catch (error) { + expect(error.message).to.equal('The operation was aborted'); + expect(cancelCalled).to.be.true; + } + }); + + it('should resolve normally if promise completes before abort', async () => { + const controller = new AbortController(); + const cancellable: Cancellable = { cancel: () => {} }; + + // Create a promise that resolves quickly + const originalPromise = new Promise(resolve => setTimeout(() => resolve('success'), 10)); + + const cancellablePromise = AbortUtil.makeCancellable(originalPromise, cancellable, controller.signal); + + // Abort after a longer delay + setTimeout(() => controller.abort(), 50); + + const result = await cancellablePromise; + expect(result).to.equal('success'); + }); + }); +}); \ No newline at end of file diff --git a/types/firestore.d.ts b/types/firestore.d.ts index cfff47c98..5d4d26677 100644 --- a/types/firestore.d.ts +++ b/types/firestore.d.ts @@ -487,6 +487,10 @@ declare namespace FirebaseFirestore { tracerProvider?: any; } + export interface FirestoreRequestOptions { + abortSignal?: AbortSignal; + } + /** Options to configure a read-only transaction. */ export interface ReadOnlyTransactionOptions { /** Set to true to indicate a read-only transaction. */ @@ -1360,7 +1364,7 @@ declare namespace FirebaseFirestore { * calls. By providing a `fieldMask`, these calls can be configured to only * return a subset of fields. */ - export interface ReadOptions { + export interface ReadOptions extends FirestoreRequestOptions { /** * Specifies the set of fields to return and reduces the amount of data * transmitted by the backend.