From 423ef62ec446e10e31b4a4d16e4449799a657b55 Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Tue, 20 Jan 2026 14:55:03 +0100 Subject: [PATCH 1/2] codex: fix cancellation and sendability --- .../AsyncChannels/AsyncBufferedChannel.swift | 2 +- .../AsyncPassthroughSubject.swift | 2 +- .../Combiners/Merge/MergeStateMachine.swift | 2 +- .../Operators/AsyncMulticastSequence.swift | 56 ++++++++++-------- .../Merge/AsyncMergeSequenceTests.swift | 58 +++++++++++++++++++ .../AsyncMulticastSequenceTests.swift | 55 ++++++++++++++++++ 6 files changed, 147 insertions(+), 28 deletions(-) diff --git a/Sources/AsyncChannels/AsyncBufferedChannel.swift b/Sources/AsyncChannels/AsyncBufferedChannel.swift index 8c97127..7ea9a55 100644 --- a/Sources/AsyncChannels/AsyncBufferedChannel.swift +++ b/Sources/AsyncChannels/AsyncBufferedChannel.swift @@ -29,7 +29,7 @@ import OrderedCollections /// sut.send(3) /// sut.finish() /// ``` -public final class AsyncBufferedChannel: AsyncSequence, Sendable { +public final class AsyncBufferedChannel: AsyncSequence, Sendable { public typealias Element = Element public typealias AsyncIterator = Iterator diff --git a/Sources/AsyncSubjects/AsyncPassthroughSubject.swift b/Sources/AsyncSubjects/AsyncPassthroughSubject.swift index 61690f1..2badeb9 100644 --- a/Sources/AsyncSubjects/AsyncPassthroughSubject.swift +++ b/Sources/AsyncSubjects/AsyncPassthroughSubject.swift @@ -30,7 +30,7 @@ /// passthrough.send(2) /// passthrough.send(.finished) /// ``` -public final class AsyncPassthroughSubject: AsyncSubject { +public final class AsyncPassthroughSubject: AsyncSubject { public typealias Element = Element public typealias Failure = Never public typealias AsyncIterator = Iterator diff --git a/Sources/Combiners/Merge/MergeStateMachine.swift b/Sources/Combiners/Merge/MergeStateMachine.swift index 18c5df3..dd848b1 100644 --- a/Sources/Combiners/Merge/MergeStateMachine.swift +++ b/Sources/Combiners/Merge/MergeStateMachine.swift @@ -240,7 +240,7 @@ struct MergeStateMachine: Sendable { } } - if case .termination = regulatedElement, case .element(.failure) = regulatedElement { + if case .element(.failure) = regulatedElement { self.task.cancel() } diff --git a/Sources/Operators/AsyncMulticastSequence.swift b/Sources/Operators/AsyncMulticastSequence.swift index 9fd6a32..f106069 100644 --- a/Sources/Operators/AsyncMulticastSequence.swift +++ b/Sources/Operators/AsyncMulticastSequence.swift @@ -105,37 +105,43 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera } func next() async { - await Task { - let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in - switch state { - case .available(let iterator): - state = .busy - return (true, iterator) - case .busy: - return (false, nil) - } - } - - guard canAccessBase, var iterator = iterator else { return } - - let toSend: Result - do { - let element = try await iterator.next() - toSend = .success(element) - } catch { - toSend = .failure(error) + guard !Task.isCancelled else { return } + + let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in + switch state { + case .available(let iterator): + state = .busy + return (true, iterator) + case .busy: + return (false, nil) } + } + guard canAccessBase, var iterator = iterator else { return } + defer { self.state.withCriticalRegion { state in state = .available(iterator) } + } - switch toSend { - case .success(.some(let element)): self.subject.send(element) - case .success(.none): self.subject.send(.finished) - case .failure(let error): self.subject.send(.failure(error)) - } - }.value + guard !Task.isCancelled else { return } + + let toSend: Result + do { + let element = try await iterator.next() + toSend = .success(element) + } catch { + guard !Task.isCancelled else { return } + toSend = .failure(error) + } + + guard !Task.isCancelled else { return } + + switch toSend { + case .success(.some(let element)): self.subject.send(element) + case .success(.none): self.subject.send(.finished) + case .failure(let error): self.subject.send(.failure(error)) + } } public func makeAsyncIterator() -> AsyncIterator { diff --git a/Tests/Combiners/Merge/AsyncMergeSequenceTests.swift b/Tests/Combiners/Merge/AsyncMergeSequenceTests.swift index 8f57a8e..995f8d4 100644 --- a/Tests/Combiners/Merge/AsyncMergeSequenceTests.swift +++ b/Tests/Combiners/Merge/AsyncMergeSequenceTests.swift @@ -41,6 +41,34 @@ private struct TimedAsyncSequence: AsyncSequence, AsyncIteratorProtocol } } +private struct CancellationAwareSequence: AsyncSequence, AsyncIteratorProtocol { + typealias Element = Element + typealias AsyncIterator = CancellationAwareSequence + + let onStart: @Sendable () -> Void + let onCancel: @Sendable () -> Void + var hasStarted = false + + mutating func next() async throws -> Element? { + if !hasStarted { + hasStarted = true + onStart() + } + + do { + try await Task.sleep(nanoseconds: 5_000_000_000) + return nil + } catch { + onCancel() + return nil + } + } + + func makeAsyncIterator() -> AsyncIterator { + self + } +} + final class AsyncMergeSequenceTests: XCTestCase { func testMerge_merges_sequences_according_to_the_timeline_using_asyncSequences() async throws { // -- 0 ------------------------------- 1000 ----------------------------- 2000 - @@ -306,4 +334,34 @@ final class AsyncMergeSequenceTests: XCTestCase { task.cancel() } + + func testMerge_cancels_other_bases_on_error() async { + let baseStartedExpectation = expectation(description: "The blocking base has started") + let baseCancelledExpectation = expectation(description: "The blocking base has been cancelled") + + let blockingBase = CancellationAwareSequence( + onStart: { baseStartedExpectation.fulfill() }, + onCancel: { baseCancelledExpectation.fulfill() } + ) + let failingBase = TimedAsyncSequence(intervalInMills: [0, 0], sequence: [1, 2], indexOfError: 1) + + let sut = merge(failingBase, blockingBase) + var iterator = sut.makeAsyncIterator() + + do { + _ = try await iterator.next() + } catch { + XCTFail("The first element should not fail") + } + await fulfillment(of: [baseStartedExpectation], timeout: 1) + + do { + _ = try await iterator.next() + XCTFail("The iteration should fail") + } catch { + XCTAssertEqual(error as? MockError, MockError(code: 1)) + } + + await fulfillment(of: [baseCancelledExpectation], timeout: 1) + } } diff --git a/Tests/Operators/AsyncMulticastSequenceTests.swift b/Tests/Operators/AsyncMulticastSequenceTests.swift index a77e4cf..bd60335 100644 --- a/Tests/Operators/AsyncMulticastSequenceTests.swift +++ b/Tests/Operators/AsyncMulticastSequenceTests.swift @@ -39,6 +39,39 @@ private class SpyAsyncSequenceForNumberOfIterators: AsyncSequence { } } +private struct CancellationAwareSequence: AsyncSequence { + typealias Element = Int + typealias AsyncIterator = Iterator + + let onStart: @Sendable () -> Void + let onCancel: @Sendable () -> Void + + func makeAsyncIterator() -> AsyncIterator { + Iterator(onStart: self.onStart, onCancel: self.onCancel) + } + + struct Iterator: AsyncIteratorProtocol { + let onStart: @Sendable () -> Void + let onCancel: @Sendable () -> Void + var hasStarted = false + + mutating func next() async throws -> Int? { + if !hasStarted { + hasStarted = true + onStart() + } + + do { + try await Task.sleep(nanoseconds: 5_000_000_000) + return nil + } catch { + onCancel() + return nil + } + } + } +} + final class AsyncMulticastSequenceTests: XCTestCase { func test_multiple_loops_receive_elements_from_single_baseIterator() { let taskHaveIterators = expectation(description: "All tasks have their iterator") @@ -156,4 +189,26 @@ final class AsyncMulticastSequenceTests: XCTestCase { XCTAssertEqual(error as? MockError, expectedError) } } + + func test_multicast_cancels_upstream_when_consumer_cancels() async { + let upstreamStartedExpectation = expectation(description: "Upstream started") + let upstreamCancelledExpectation = expectation(description: "Upstream cancelled") + + let base = CancellationAwareSequence( + onStart: { upstreamStartedExpectation.fulfill() }, + onCancel: { upstreamCancelledExpectation.fulfill() } + ) + let stream = AsyncThrowingPassthroughSubject() + let sut = base.multicast(stream).autoconnect() + + let task = Task { + var iterator = sut.makeAsyncIterator() + _ = try? await iterator.next() + } + + await fulfillment(of: [upstreamStartedExpectation], timeout: 1) + task.cancel() + + await fulfillment(of: [upstreamCancelledExpectation], timeout: 1) + } } From 3498132f4fb6aef7b0e53b119ba9361fb0c4759b Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Tue, 20 Jan 2026 15:38:59 +0100 Subject: [PATCH 2/2] codex: address warnings and deprecations --- .../AsyncChannels/AsyncBufferedChannel.swift | 42 ++++++------ .../AsyncThrowingBufferedChannel.swift | 42 ++++++------ .../Combiners/Merge/MergeStateMachine.swift | 4 +- .../AsyncWithLatestFrom2Sequence.swift | 12 ++-- .../AsyncWithLatestFromSequence.swift | 6 +- Sources/Combiners/Zip/Zip2Runtime.swift | 12 ++-- Sources/Combiners/Zip/Zip3Runtime.swift | 12 ++-- Sources/Combiners/Zip/ZipRuntime.swift | 12 ++-- Sources/Creators/AsyncTimerSequence.swift | 6 +- .../Operators/AsyncMulticastSequence.swift | 67 ++++++++----------- Sources/Operators/AsyncPrependSequence.swift | 6 +- .../AsyncSwitchToLatestSequence.swift | 12 ++-- Sources/Supporting/Regulator.swift | 4 +- .../AsyncBufferedChannelTests.swift | 14 ++-- .../AsyncBufferedThrowingChannelTests.swift | 18 ++--- .../AsyncCurrentValueSubjectTests.swift | 8 +-- .../AsyncPassthroughSubjectTests.swift | 10 +-- .../AsyncReplaySubjectTests.swift | 8 +-- ...syncThrowingCurrentValueSubjectTests.swift | 12 ++-- ...AsyncThrowingPassthroughSubjectTests.swift | 16 ++--- .../AsyncThrowingReplaySubjectTests.swift | 12 ++-- .../AsyncWithLatestFrom2SequenceTests.swift | 14 ++-- .../AsyncWithLatestFromSequenceTests.swift | 10 +-- .../Combiners/Zip/AsyncZipSequenceTests.swift | 6 +- Tests/Creators/AsyncFailSequenceTests.swift | 2 +- Tests/Creators/AsyncJustSequenceTests.swift | 2 +- Tests/Creators/AsyncLazySequenceTests.swift | 2 +- .../AsyncThrowingJustSequenceTests.swift | 2 +- .../AsyncHandleEventsSequenceTests.swift | 6 +- .../AsyncMulticastSequenceTests.swift | 54 --------------- .../Operators/AsyncPrependSequenceTests.swift | 2 +- Tests/Operators/AsyncScanSequenceTests.swift | 2 +- .../AsyncSequence+FlatMapLatestTests.swift | 2 +- .../AsyncSwitchToLatestSequenceTests.swift | 8 +-- 34 files changed, 191 insertions(+), 256 deletions(-) diff --git a/Sources/AsyncChannels/AsyncBufferedChannel.swift b/Sources/AsyncChannels/AsyncBufferedChannel.swift index 7ea9a55..2f28efe 100644 --- a/Sources/AsyncChannels/AsyncBufferedChannel.swift +++ b/Sources/AsyncChannels/AsyncBufferedChannel.swift @@ -157,27 +157,7 @@ public final class AsyncBufferedChannel: AsyncSequence, Senda let awaitingId = self.generateId() let cancellation = ManagedCriticalState(false) - return await withTaskCancellationHandler { [state] in - let awaiting = state.withCriticalRegion { state -> Awaiting? in - cancellation.withCriticalRegion { cancellation in - cancellation = true - } - switch state { - case .awaiting(var awaitings): - let awaiting = awaitings.remove(.placeHolder(id: awaitingId)) - if awaitings.isEmpty { - state = .idle - } else { - state = .awaiting(awaitings) - } - return awaiting - default: - return nil - } - } - - awaiting?.continuation?.resume(returning: nil) - } operation: { + return await withTaskCancellationHandler { await withUnsafeContinuation { [state] (continuation: UnsafeContinuation) in let decision = state.withCriticalRegion { state -> AwaitingDecision in let isCancelled = cancellation.withCriticalRegion { $0 } @@ -218,6 +198,26 @@ public final class AsyncBufferedChannel: AsyncSequence, Senda onSuspend?() } } + } onCancel: { [state] in + let awaiting = state.withCriticalRegion { state -> Awaiting? in + cancellation.withCriticalRegion { cancellation in + cancellation = true + } + switch state { + case .awaiting(var awaitings): + let awaiting = awaitings.remove(.placeHolder(id: awaitingId)) + if awaitings.isEmpty { + state = .idle + } else { + state = .awaiting(awaitings) + } + return awaiting + default: + return nil + } + } + + awaiting?.continuation?.resume(returning: nil) } } diff --git a/Sources/AsyncChannels/AsyncThrowingBufferedChannel.swift b/Sources/AsyncChannels/AsyncThrowingBufferedChannel.swift index f34c80e..073f0b7 100644 --- a/Sources/AsyncChannels/AsyncThrowingBufferedChannel.swift +++ b/Sources/AsyncChannels/AsyncThrowingBufferedChannel.swift @@ -178,27 +178,7 @@ public final class AsyncThrowingBufferedChannel: AsyncS let awaitingId = self.generateId() let cancellation = ManagedCriticalState(false) - return try await withTaskCancellationHandler { [state] in - let awaiting = state.withCriticalRegion { state -> Awaiting? in - cancellation.withCriticalRegion { cancellation in - cancellation = true - } - switch state { - case .awaiting(var awaitings): - let awaiting = awaitings.remove(.placeHolder(id: awaitingId)) - if awaitings.isEmpty { - state = .idle - } else { - state = .awaiting(awaitings) - } - return awaiting - default: - return nil - } - } - - awaiting?.continuation?.resume(returning: nil) - } operation: { + return try await withTaskCancellationHandler { try await withUnsafeThrowingContinuation { [state] (continuation: UnsafeContinuation) in let decision = state.withCriticalRegion { state -> AwaitingDecision in let isCancelled = cancellation.withCriticalRegion { $0 } @@ -245,6 +225,26 @@ public final class AsyncThrowingBufferedChannel: AsyncS onSuspend?() } } + } onCancel: { [state] in + let awaiting = state.withCriticalRegion { state -> Awaiting? in + cancellation.withCriticalRegion { cancellation in + cancellation = true + } + switch state { + case .awaiting(var awaitings): + let awaiting = awaitings.remove(.placeHolder(id: awaitingId)) + if awaitings.isEmpty { + state = .idle + } else { + state = .awaiting(awaitings) + } + return awaiting + default: + return nil + } + } + + awaiting?.continuation?.resume(returning: nil) } } diff --git a/Sources/Combiners/Merge/MergeStateMachine.swift b/Sources/Combiners/Merge/MergeStateMachine.swift index dd848b1..2f7984a 100644 --- a/Sources/Combiners/Merge/MergeStateMachine.swift +++ b/Sources/Combiners/Merge/MergeStateMachine.swift @@ -197,8 +197,6 @@ struct MergeStateMachine: Sendable { func next() async -> RegulatedElement { await withTaskCancellationHandler { - self.unsuspendAndClearOnCancel() - } operation: { self.requestNextRegulatedElements() let regulatedElement = await withUnsafeContinuation { (continuation: UnsafeContinuation, Never>) in @@ -245,6 +243,8 @@ struct MergeStateMachine: Sendable { } return regulatedElement + } onCancel: { + self.unsuspendAndClearOnCancel() } } } diff --git a/Sources/Combiners/WithLatestFrom/AsyncWithLatestFrom2Sequence.swift b/Sources/Combiners/WithLatestFrom/AsyncWithLatestFrom2Sequence.swift index d10fe9f..b5007df 100644 --- a/Sources/Combiners/WithLatestFrom/AsyncWithLatestFrom2Sequence.swift +++ b/Sources/Combiners/WithLatestFrom/AsyncWithLatestFrom2Sequence.swift @@ -172,12 +172,7 @@ where Other1: Sendable, Other2: Sendable, Other1.Element: Sendable, Other2.Eleme let shouldReturnNil = self.isTerminated.withCriticalRegion { $0 } guard !shouldReturnNil else { return nil } - return try await withTaskCancellationHandler { [isTerminated, othersTask] in - isTerminated.withCriticalRegion { isTerminated in - isTerminated = true - } - othersTask?.cancel() - } operation: { [othersTask, othersState, onBaseElement] in + return try await withTaskCancellationHandler { [othersTask, othersState, onBaseElement] in do { while true { guard let baseElement = try await self.base.next() else { @@ -219,6 +214,11 @@ where Other1: Sendable, Other2: Sendable, Other1.Element: Sendable, Other2.Eleme othersTask?.cancel() throw error } + } onCancel: { [isTerminated, othersTask] in + isTerminated.withCriticalRegion { isTerminated in + isTerminated = true + } + othersTask?.cancel() } } } diff --git a/Sources/Combiners/WithLatestFrom/AsyncWithLatestFromSequence.swift b/Sources/Combiners/WithLatestFrom/AsyncWithLatestFromSequence.swift index d02d44f..2eb9059 100644 --- a/Sources/Combiners/WithLatestFrom/AsyncWithLatestFromSequence.swift +++ b/Sources/Combiners/WithLatestFrom/AsyncWithLatestFromSequence.swift @@ -121,9 +121,7 @@ where Other: Sendable, Other.Element: Sendable { public mutating func next() async rethrows -> Element? { guard !self.isTerminated else { return nil } - return try await withTaskCancellationHandler { [otherTask] in - otherTask?.cancel() - } operation: { [otherTask, otherState, onBaseElement] in + return try await withTaskCancellationHandler { [otherTask, otherState, onBaseElement] in do { while true { guard let baseElement = try await self.base.next() else { @@ -157,6 +155,8 @@ where Other: Sendable, Other.Element: Sendable { otherTask?.cancel() throw error } + } onCancel: { [otherTask] in + otherTask?.cancel() } } } diff --git a/Sources/Combiners/Zip/Zip2Runtime.swift b/Sources/Combiners/Zip/Zip2Runtime.swift index 699b8ca..f59cbb6 100644 --- a/Sources/Combiners/Zip/Zip2Runtime.swift +++ b/Sources/Combiners/Zip/Zip2Runtime.swift @@ -148,12 +148,6 @@ where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: func next() async rethrows -> (Base1.Element, Base2.Element)? { try await withTaskCancellationHandler { - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - - self.handle(rootTaskIsCancelledOutput: output) - } operation: { let results = await withUnsafeContinuation { (continuation: UnsafeContinuation<(Result, Result)?, Never>) in let output = self.stateMachine.withCriticalRegion { stateMachine in stateMachine.newDemandFromConsumer(suspendedDemand: continuation) @@ -173,6 +167,12 @@ where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: self.handle(demandIsFulfilledOutput: output) return try (results.0._rethrowGet(), results.1._rethrowGet()) + } onCancel: { + let output = self.stateMachine.withCriticalRegion { stateMachine in + stateMachine.rootTaskIsCancelled() + } + + self.handle(rootTaskIsCancelledOutput: output) } } diff --git a/Sources/Combiners/Zip/Zip3Runtime.swift b/Sources/Combiners/Zip/Zip3Runtime.swift index eeb98f9..9b8cc18 100644 --- a/Sources/Combiners/Zip/Zip3Runtime.swift +++ b/Sources/Combiners/Zip/Zip3Runtime.swift @@ -186,12 +186,6 @@ where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: func next() async rethrows -> (Base1.Element, Base2.Element, Base3.Element)? { try await withTaskCancellationHandler { - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - - self.handle(rootTaskIsCancelledOutput: output) - } operation: { let results = await withUnsafeContinuation { (continuation: UnsafeContinuation<(Result, Result, Result)?, Never>) in let output = self.stateMachine.withCriticalRegion { stateMachine in stateMachine.newDemandFromConsumer(suspendedDemand: continuation) @@ -211,6 +205,12 @@ where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: self.handle(demandIsFulfilledOutput: output) return try (results.0._rethrowGet(), results.1._rethrowGet(), results.2._rethrowGet()) + } onCancel: { + let output = self.stateMachine.withCriticalRegion { stateMachine in + stateMachine.rootTaskIsCancelled() + } + + self.handle(rootTaskIsCancelledOutput: output) } } diff --git a/Sources/Combiners/Zip/ZipRuntime.swift b/Sources/Combiners/Zip/ZipRuntime.swift index be04f9d..1fa4df8 100644 --- a/Sources/Combiners/Zip/ZipRuntime.swift +++ b/Sources/Combiners/Zip/ZipRuntime.swift @@ -120,12 +120,6 @@ where Base: Sendable, Base.Element: Sendable { func next() async rethrows -> [Base.Element]? { try await withTaskCancellationHandler { - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - - self.handle(rootTaskIsCancelledOutput: output) - } operation: { let results = await withUnsafeContinuation { (continuation: UnsafeContinuation<[Int: Result]?, Never>) in let output = self.stateMachine.withCriticalRegion { stateMachine in stateMachine.newDemandFromConsumer(suspendedDemand: continuation) @@ -145,6 +139,12 @@ where Base: Sendable, Base.Element: Sendable { self.handle(demandIsFulfilledOutput: output) return try results.sorted { $0.key < $1.key }.map { try $0.value._rethrowGet() } + } onCancel: { + let output = self.stateMachine.withCriticalRegion { stateMachine in + stateMachine.rootTaskIsCancelled() + } + + self.handle(rootTaskIsCancelledOutput: output) } } diff --git a/Sources/Creators/AsyncTimerSequence.swift b/Sources/Creators/AsyncTimerSequence.swift index 5633d82..a0d6146 100644 --- a/Sources/Creators/AsyncTimerSequence.swift +++ b/Sources/Creators/AsyncTimerSequence.swift @@ -78,11 +78,11 @@ public struct AsyncTimerSequence: AsyncSequence { } public mutating func next() async -> Element? { - await withTaskCancellationHandler { [task] in - task.cancel() - } operation: { + await withTaskCancellationHandler { guard !Task.isCancelled else { return nil } return await self.iterator.next() + } onCancel: { [task] in + task.cancel() } } } diff --git a/Sources/Operators/AsyncMulticastSequence.swift b/Sources/Operators/AsyncMulticastSequence.swift index f106069..1cc4971 100644 --- a/Sources/Operators/AsyncMulticastSequence.swift +++ b/Sources/Operators/AsyncMulticastSequence.swift @@ -101,47 +101,42 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera /// Allow the `AsyncIterator` to produce elements. public func connect() { + self.isConnected.apply(criticalState: true) self.connectedGate.send(()) } func next() async { - guard !Task.isCancelled else { return } - - let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in - switch state { - case .available(let iterator): - state = .busy - return (true, iterator) - case .busy: - return (false, nil) + await Task { + let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in + switch state { + case .available(let iterator): + state = .busy + return (true, iterator) + case .busy: + return (false, nil) + } + } + + guard canAccessBase, var iterator = iterator else { return } + + let toSend: Result + do { + let element = try await iterator.next() + toSend = .success(element) + } catch { + toSend = .failure(error) } - } - guard canAccessBase, var iterator = iterator else { return } - defer { self.state.withCriticalRegion { state in state = .available(iterator) } - } - - guard !Task.isCancelled else { return } - let toSend: Result - do { - let element = try await iterator.next() - toSend = .success(element) - } catch { - guard !Task.isCancelled else { return } - toSend = .failure(error) - } - - guard !Task.isCancelled else { return } - - switch toSend { - case .success(.some(let element)): self.subject.send(element) - case .success(.none): self.subject.send(.finished) - case .failure(let error): self.subject.send(.failure(error)) - } + switch toSend { + case .success(.some(let element)): self.subject.send(element) + case .success(.none): self.subject.send(.finished) + case .failure(let error): self.subject.send(.failure(error)) + } + }.value } public func makeAsyncIterator() -> AsyncIterator { @@ -163,14 +158,8 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera public mutating func next() async rethrows -> Element? { guard !Task.isCancelled else { return nil } - let shouldWaitForGate = self.isConnected.withCriticalRegion { isConnected -> Bool in - if !isConnected { - isConnected = true - return true - } - return false - } - if shouldWaitForGate { + let isConnected = self.isConnected.withCriticalRegion { $0 } + if !isConnected { await self.connectedGateIterator.next() } diff --git a/Sources/Operators/AsyncPrependSequence.swift b/Sources/Operators/AsyncPrependSequence.swift index dbf6b4b..0935fed 100644 --- a/Sources/Operators/AsyncPrependSequence.swift +++ b/Sources/Operators/AsyncPrependSequence.swift @@ -54,12 +54,12 @@ public struct AsyncPrependSequence: AsyncSequence { public struct Iterator: AsyncIteratorProtocol { var base: Base.AsyncIterator - var prependElement: () async throws -> Element + var prependElement: @Sendable () -> Element var hasBeenDelivered = false public init( base: Base.AsyncIterator, - prependElement: @escaping () async throws -> Element + prependElement: @Sendable @escaping () -> Element ) { self.base = base self.prependElement = prependElement @@ -70,7 +70,7 @@ public struct AsyncPrependSequence: AsyncSequence { if !self.hasBeenDelivered { self.hasBeenDelivered = true - return try await prependElement() + return prependElement() } return try await self.base.next() diff --git a/Sources/Operators/AsyncSwitchToLatestSequence.swift b/Sources/Operators/AsyncSwitchToLatestSequence.swift index 736df25..30147b8 100644 --- a/Sources/Operators/AsyncSwitchToLatestSequence.swift +++ b/Sources/Operators/AsyncSwitchToLatestSequence.swift @@ -221,12 +221,7 @@ where Base.Element: AsyncSequence, Base: Sendable, Base.Element.Element: Sendabl guard !Task.isCancelled else { return nil } self.startBase() - return try await withTaskCancellationHandler { [baseTask, state] in - baseTask?.cancel() - state.withCriticalRegion { - $0.childTask?.cancel() - } - } operation: { + return try await withTaskCancellationHandler { while true { let childTask = await withUnsafeContinuation { [state] (continuation: UnsafeContinuation?, Never>) in let decision = state.withCriticalRegion { state -> NextDecision in @@ -303,6 +298,11 @@ where Base.Element: AsyncSequence, Base: Sendable, Base.Element.Element: Sendabl return try element._rethrowGet() } } + } onCancel: { [baseTask, state] in + baseTask?.cancel() + state.withCriticalRegion { + $0.childTask?.cancel() + } } } } diff --git a/Sources/Supporting/Regulator.swift b/Sources/Supporting/Regulator.swift index 1e11f2b..456469f 100644 --- a/Sources/Supporting/Regulator.swift +++ b/Sources/Supporting/Regulator.swift @@ -48,8 +48,6 @@ final class Regulator: @unchecked Sendable { func iterate() async { await withTaskCancellationHandler { - self.unsuspendAndExitOnCancel() - } operation: { var mutableBase = base.makeAsyncIterator() do { @@ -99,6 +97,8 @@ final class Regulator: @unchecked Sendable { } self.onNextRegulatedElement(.element(result: .failure(error))) } + } onCancel: { + self.unsuspendAndExitOnCancel() } } diff --git a/Tests/AsyncChannels/AsyncBufferedChannelTests.swift b/Tests/AsyncChannels/AsyncBufferedChannelTests.swift index a032c89..1476e76 100644 --- a/Tests/AsyncChannels/AsyncBufferedChannelTests.swift +++ b/Tests/AsyncChannels/AsyncBufferedChannelTests.swift @@ -46,7 +46,7 @@ final class AsyncBufferedChannelTests: XCTestCase { return received } - wait(for: [iterationIsAwaiting], timeout: 1.0) + await fulfillment(of: [iterationIsAwaiting], timeout: 1.0) // When sut.send(1) @@ -125,19 +125,19 @@ final class AsyncBufferedChannelTests: XCTestCase { for await element in sut { received = element taskCanBeCancelled.fulfill() - wait(for: [taskWasCancelled], timeout: 1.0) + await fulfillment(of: [taskWasCancelled], timeout: 1.0) } iterationHasFinished.fulfill() return received } - wait(for: [taskCanBeCancelled], timeout: 1.0) + await fulfillment(of: [taskCanBeCancelled], timeout: 1.0) // When task.cancel() taskWasCancelled.fulfill() - wait(for: [iterationHasFinished], timeout: 1.0) + await fulfillment(of: [iterationHasFinished], timeout: 1.0) // Then let received = await task.value @@ -170,12 +170,12 @@ final class AsyncBufferedChannelTests: XCTestCase { return received } - wait(for: [iteration1IsAwaiting, iteration2IsAwaiting], timeout: 1.0) + await fulfillment(of: [iteration1IsAwaiting, iteration2IsAwaiting], timeout: 1.0) // When sut.finish() - wait(for: [iteration1IsFinished, iteration2IsFinished], timeout: 1.0) + await fulfillment(of: [iteration1IsFinished, iteration2IsFinished], timeout: 1.0) let received1 = await task1.value let received2 = await task2.value @@ -217,7 +217,7 @@ final class AsyncBufferedChannelTests: XCTestCase { }.cancel() // Then - wait(for: [iterationIsFinished], timeout: 1.0) + await fulfillment(of: [iterationIsFinished], timeout: 1.0) } func test_awaiting_uses_id_for_equatable() { diff --git a/Tests/AsyncChannels/AsyncBufferedThrowingChannelTests.swift b/Tests/AsyncChannels/AsyncBufferedThrowingChannelTests.swift index 7bef3ec..1b01c9d 100644 --- a/Tests/AsyncChannels/AsyncBufferedThrowingChannelTests.swift +++ b/Tests/AsyncChannels/AsyncBufferedThrowingChannelTests.swift @@ -46,7 +46,7 @@ final class AsyncBufferedThrowingChannelTests: XCTestCase { return received } - wait(for: [iterationIsAwaiting], timeout: 1.0) + await fulfillment(of: [iterationIsAwaiting], timeout: 1.0) // When sut.send(1) @@ -125,19 +125,19 @@ final class AsyncBufferedThrowingChannelTests: XCTestCase { for try await element in sut { received = element taskCanBeCancelled.fulfill() - wait(for: [taskWasCancelled], timeout: 1.0) + await fulfillment(of: [taskWasCancelled], timeout: 1.0) } iterationHasFinished.fulfill() return received } - wait(for: [taskCanBeCancelled], timeout: 1.0) + await fulfillment(of: [taskCanBeCancelled], timeout: 1.0) // When task.cancel() taskWasCancelled.fulfill() - wait(for: [iterationHasFinished], timeout: 1.0) + await fulfillment(of: [iterationHasFinished], timeout: 1.0) // Then let received = try await task.value @@ -211,13 +211,13 @@ final class AsyncBufferedThrowingChannelTests: XCTestCase { } } - wait(for: [iteration1IsAwaiting, iteration2IsAwaiting], timeout: 1.0) + await fulfillment(of: [iteration1IsAwaiting, iteration2IsAwaiting], timeout: 1.0) // When sut.fail(MockError(code: 1701)) // Then - wait(for: [iteration1HasThrown, iteration2HasThrown], timeout: 1.0) + await fulfillment(of: [iteration1HasThrown, iteration2HasThrown], timeout: 1.0) let iterator = sut.makeAsyncIterator() do { @@ -254,12 +254,12 @@ final class AsyncBufferedThrowingChannelTests: XCTestCase { return received } - wait(for: [iteration1IsAwaiting, iteration2IsAwaiting], timeout: 1.0) + await fulfillment(of: [iteration1IsAwaiting, iteration2IsAwaiting], timeout: 1.0) // When sut.finish() - wait(for: [iteration1IsFinished, iteration2IsFinished], timeout: 1.0) + await fulfillment(of: [iteration1IsFinished, iteration2IsFinished], timeout: 1.0) let received1 = try await task1.value let received2 = try await task2.value @@ -301,7 +301,7 @@ final class AsyncBufferedThrowingChannelTests: XCTestCase { }.cancel() // Then - wait(for: [iterationIsFinished], timeout: 1.0) + await fulfillment(of: [iterationIsFinished], timeout: 1.0) } func test_awaiting_uses_id_for_equatable() { diff --git a/Tests/AsyncSubjets/AsyncCurrentValueSubjectTests.swift b/Tests/AsyncSubjets/AsyncCurrentValueSubjectTests.swift index cf6a9c8..3d2d9e2 100644 --- a/Tests/AsyncSubjets/AsyncCurrentValueSubjectTests.swift +++ b/Tests/AsyncSubjets/AsyncCurrentValueSubjectTests.swift @@ -107,11 +107,11 @@ final class AsyncCurrentValueSubjectTests: XCTestCase { hasFinishedExpectation.fulfill() } - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.finished) - wait(for: [hasFinishedExpectation], timeout: 1) + await fulfillment(of: [hasFinishedExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() let received = await iterator.next() @@ -130,7 +130,7 @@ final class AsyncCurrentValueSubjectTests: XCTestCase { for await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 1) taskHasFinishedExpectation.fulfill() @@ -175,7 +175,7 @@ final class AsyncCurrentValueSubjectTests: XCTestCase { return received.sorted() } - await waitForExpectations(timeout: 1) + await fulfillment(of: [canSendExpectation], timeout: 1) // concurrently push values in the sut 1 let task1 = Task { diff --git a/Tests/AsyncSubjets/AsyncPassthroughSubjectTests.swift b/Tests/AsyncSubjets/AsyncPassthroughSubjectTests.swift index 2cbac1c..728ed28 100644 --- a/Tests/AsyncSubjets/AsyncPassthroughSubjectTests.swift +++ b/Tests/AsyncSubjets/AsyncPassthroughSubjectTests.swift @@ -91,15 +91,15 @@ final class AsyncPassthroughSubjectTests: XCTestCase { hasFinishedExpectation.fulfill() } - wait(for: [isReadyToBeIteratedExpectation], timeout: 1) + await fulfillment(of: [isReadyToBeIteratedExpectation], timeout: 1) sut.send(1) - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send( .finished) - wait(for: [hasFinishedExpectation], timeout: 1) + await fulfillment(of: [hasFinishedExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() let received = await iterator.next() @@ -122,7 +122,7 @@ final class AsyncPassthroughSubjectTests: XCTestCase { while let element = await it.next() { receivedElements.append(element) canCancelExpectation.fulfill() - wait(for: [hasCancelExpectation], timeout: 5) + await fulfillment(of: [hasCancelExpectation], timeout: 5) } XCTAssertEqual(receivedElements, [1]) taskHasFinishedExpectation.fulfill() @@ -171,7 +171,7 @@ final class AsyncPassthroughSubjectTests: XCTestCase { return received.sorted() } - await waitForExpectations(timeout: 1) + await fulfillment(of: [canSendExpectation], timeout: 1) // concurrently push values in the sut 1 let task1 = Task { diff --git a/Tests/AsyncSubjets/AsyncReplaySubjectTests.swift b/Tests/AsyncSubjets/AsyncReplaySubjectTests.swift index e4a3857..0f824fc 100644 --- a/Tests/AsyncSubjets/AsyncReplaySubjectTests.swift +++ b/Tests/AsyncSubjets/AsyncReplaySubjectTests.swift @@ -130,11 +130,11 @@ final class AsyncReplaySubjectTests: XCTestCase { hasFinishedExpectation.fulfill() } - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.finished) - wait(for: [hasFinishedExpectation], timeout: 1) + await fulfillment(of: [hasFinishedExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() let received = await iterator.next() @@ -155,7 +155,7 @@ final class AsyncReplaySubjectTests: XCTestCase { for await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 1) taskHasFinishedExpectation.fulfill() @@ -200,7 +200,7 @@ final class AsyncReplaySubjectTests: XCTestCase { return received.sorted() } - await waitForExpectations(timeout: 1) + await fulfillment(of: [canSendExpectation], timeout: 1) // concurrently push values in the sut 1 let task1 = Task { diff --git a/Tests/AsyncSubjets/AsyncThrowingCurrentValueSubjectTests.swift b/Tests/AsyncSubjets/AsyncThrowingCurrentValueSubjectTests.swift index 2b323d0..c2cfba0 100644 --- a/Tests/AsyncSubjets/AsyncThrowingCurrentValueSubjectTests.swift +++ b/Tests/AsyncSubjets/AsyncThrowingCurrentValueSubjectTests.swift @@ -107,11 +107,11 @@ final class AsyncThrowingCurrentValueSubjectTests: XCTestCase { hasFinishedExpectation.fulfill() } - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.finished) - wait(for: [hasFinishedExpectation], timeout: 1) + await fulfillment(of: [hasFinishedExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() let received = try await iterator.next() @@ -155,11 +155,11 @@ final class AsyncThrowingCurrentValueSubjectTests: XCTestCase { } } - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.failure(expectedError)) - wait(for: [hasFinishedWithFailureExpectation], timeout: 1) + await fulfillment(of: [hasFinishedWithFailureExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() do { @@ -182,7 +182,7 @@ final class AsyncThrowingCurrentValueSubjectTests: XCTestCase { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 1) taskHasFinishedExpectation.fulfill() @@ -227,7 +227,7 @@ final class AsyncThrowingCurrentValueSubjectTests: XCTestCase { return received.sorted() } - await waitForExpectations(timeout: 1) + await fulfillment(of: [canSendExpectation], timeout: 1) // concurrently push values in the sut 1 let task1 = Task { diff --git a/Tests/AsyncSubjets/AsyncThrowingPassthroughSubjectTests.swift b/Tests/AsyncSubjets/AsyncThrowingPassthroughSubjectTests.swift index 8546868..e9838dc 100644 --- a/Tests/AsyncSubjets/AsyncThrowingPassthroughSubjectTests.swift +++ b/Tests/AsyncSubjets/AsyncThrowingPassthroughSubjectTests.swift @@ -91,15 +91,15 @@ final class AsyncThrowingPassthroughSubjectTests: XCTestCase { hasFinishedExpectation.fulfill() } - wait(for: [isReadyToBeIteratedExpectation], timeout: 1) + await fulfillment(of: [isReadyToBeIteratedExpectation], timeout: 1) sut.send(1) - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send( .finished) - wait(for: [hasFinishedExpectation], timeout: 1) + await fulfillment(of: [hasFinishedExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() let received = try await iterator.next() @@ -150,15 +150,15 @@ final class AsyncThrowingPassthroughSubjectTests: XCTestCase { } } - wait(for: [isReadyToBeIteratedExpectation], timeout: 1) + await fulfillment(of: [isReadyToBeIteratedExpectation], timeout: 1) sut.send(1) - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.failure(expectedError)) - wait(for: [hasFinishedWithFailureExpectation], timeout: 1) + await fulfillment(of: [hasFinishedWithFailureExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() do { @@ -185,7 +185,7 @@ final class AsyncThrowingPassthroughSubjectTests: XCTestCase { while let element = try await it.next() { receivedElements.append(element) canCancelExpectation.fulfill() - wait(for: [hasCancelExpectation], timeout: 5) + await fulfillment(of: [hasCancelExpectation], timeout: 5) } XCTAssertEqual(receivedElements, [1]) taskHasFinishedExpectation.fulfill() @@ -234,7 +234,7 @@ final class AsyncThrowingPassthroughSubjectTests: XCTestCase { return received.sorted() } - await waitForExpectations(timeout: 1) + await fulfillment(of: [canSendExpectation], timeout: 1) // concurrently push values in the sut 1 let task1 = Task { diff --git a/Tests/AsyncSubjets/AsyncThrowingReplaySubjectTests.swift b/Tests/AsyncSubjets/AsyncThrowingReplaySubjectTests.swift index df5edde..222ef5a 100644 --- a/Tests/AsyncSubjets/AsyncThrowingReplaySubjectTests.swift +++ b/Tests/AsyncSubjets/AsyncThrowingReplaySubjectTests.swift @@ -130,11 +130,11 @@ final class AsyncThrowingReplaySubjectTests: XCTestCase { hasFinishedExpectation.fulfill() } - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.finished) - wait(for: [hasFinishedExpectation], timeout: 1) + await fulfillment(of: [hasFinishedExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() let received = try await iterator.next() @@ -180,11 +180,11 @@ final class AsyncThrowingReplaySubjectTests: XCTestCase { } } - wait(for: [hasReceivedOneElementExpectation], timeout: 1) + await fulfillment(of: [hasReceivedOneElementExpectation], timeout: 1) sut.send(.failure(expectedError)) - wait(for: [hasFinishedWithFailureExpectation], timeout: 1) + await fulfillment(of: [hasFinishedWithFailureExpectation], timeout: 1) var iterator = sut.makeAsyncIterator() do { @@ -209,7 +209,7 @@ final class AsyncThrowingReplaySubjectTests: XCTestCase { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 1) taskHasFinishedExpectation.fulfill() @@ -254,7 +254,7 @@ final class AsyncThrowingReplaySubjectTests: XCTestCase { return received.sorted() } - await waitForExpectations(timeout: 1) + await fulfillment(of: [canSendExpectation], timeout: 1) // concurrently push values in the sut 1 let task1 = Task { diff --git a/Tests/Combiners/WithLatestFrom/AsyncWithLatestFrom2SequenceTests.swift b/Tests/Combiners/WithLatestFrom/AsyncWithLatestFrom2SequenceTests.swift index 8a198da..915e042 100644 --- a/Tests/Combiners/WithLatestFrom/AsyncWithLatestFrom2SequenceTests.swift +++ b/Tests/Combiners/WithLatestFrom/AsyncWithLatestFrom2SequenceTests.swift @@ -60,11 +60,11 @@ final class AsyncWithLatestFrom2SequenceTests: XCTestCase { Task { base.send(0) - wait(for: [baseHasProduced0], timeout: 1.0) + await fulfillment(of: [baseHasProduced0], timeout: 1.0) other1.send("a") - wait(for: [other1HasProducedA], timeout: 1.0) + await fulfillment(of: [other1HasProducedA], timeout: 1.0) other2.send("x") - wait(for: [other2HasProducedX], timeout: 1.0) + await fulfillment(of: [other2HasProducedX], timeout: 1.0) base.send(1) } @@ -73,7 +73,7 @@ final class AsyncWithLatestFrom2SequenceTests: XCTestCase { Task { other2.send("y") - wait(for: [other2HasProducedY], timeout: 1.0) + await fulfillment(of: [other2HasProducedY], timeout: 1.0) base.send(2) } @@ -82,7 +82,7 @@ final class AsyncWithLatestFrom2SequenceTests: XCTestCase { Task { other1.send("b") - wait(for: [other1HasProducedB], timeout: 1.0) + await fulfillment(of: [other1HasProducedB], timeout: 1.0) base.send(3) } @@ -209,11 +209,11 @@ final class AsyncWithLatestFrom2SequenceTests: XCTestCase { } // ensure the other task actually starts - wait(for: [iterated], timeout: 5.0) + await fulfillment(of: [iterated], timeout: 5.0) // cancellation should ensure the loop finishes // without regards to the remaining underlying sequence task.cancel() - wait(for: [finished], timeout: 5.0) + await fulfillment(of: [finished], timeout: 5.0) } } diff --git a/Tests/Combiners/WithLatestFrom/AsyncWithLatestFromSequenceTests.swift b/Tests/Combiners/WithLatestFrom/AsyncWithLatestFromSequenceTests.swift index 2455df1..732673f 100644 --- a/Tests/Combiners/WithLatestFrom/AsyncWithLatestFromSequenceTests.swift +++ b/Tests/Combiners/WithLatestFrom/AsyncWithLatestFromSequenceTests.swift @@ -46,9 +46,9 @@ final class AsyncWithLatestFromSequenceTests: XCTestCase { Task { base.send(0) - wait(for: [baseHasProduced0], timeout: 1.0) + await fulfillment(of: [baseHasProduced0], timeout: 1.0) other.send("a") - wait(for: [otherHasProducedA], timeout: 1.0) + await fulfillment(of: [otherHasProducedA], timeout: 1.0) base.send(1) } @@ -63,7 +63,7 @@ final class AsyncWithLatestFromSequenceTests: XCTestCase { Task { other.send("b") other.send("c") - wait(for: [otherHasProducedC], timeout: 1.0) + await fulfillment(of: [otherHasProducedC], timeout: 1.0) base.send(3) } @@ -151,11 +151,11 @@ final class AsyncWithLatestFromSequenceTests: XCTestCase { } // ensure the other task actually starts - wait(for: [iterated], timeout: 1.0) + await fulfillment(of: [iterated], timeout: 1.0) // cancellation should ensure the loop finishes // without regards to the remaining underlying sequence task.cancel() - wait(for: [finished], timeout: 1.0) + await fulfillment(of: [finished], timeout: 1.0) } } diff --git a/Tests/Combiners/Zip/AsyncZipSequenceTests.swift b/Tests/Combiners/Zip/AsyncZipSequenceTests.swift index 701bd48..68eafaa 100644 --- a/Tests/Combiners/Zip/AsyncZipSequenceTests.swift +++ b/Tests/Combiners/Zip/AsyncZipSequenceTests.swift @@ -140,7 +140,7 @@ final class AsyncZipSequenceTests: XCTestCase { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement!.0, 1) XCTAssertEqual(firstElement!.1, "1") @@ -313,7 +313,7 @@ extension AsyncZipSequenceTests { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement!.0, 1) // the AsyncSequence is cancelled having only emitted the first element XCTAssertEqual(firstElement!.1, "1") @@ -398,7 +398,7 @@ extension AsyncZipSequenceTests { for await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement!, [1, 1, 1, 1, 1]) taskHasFinishedExpectation.fulfill() diff --git a/Tests/Creators/AsyncFailSequenceTests.swift b/Tests/Creators/AsyncFailSequenceTests.swift index 24790fb..4f55f53 100644 --- a/Tests/Creators/AsyncFailSequenceTests.swift +++ b/Tests/Creators/AsyncFailSequenceTests.swift @@ -41,7 +41,7 @@ final class AsyncFailSequenceTests: XCTestCase { let task = Task { do { var iterator = failSequence.makeAsyncIterator() - wait(for: [taskHasBeenCancelledExpectation], timeout: 1) + await fulfillment(of: [taskHasBeenCancelledExpectation], timeout: 1) while let _ = try await iterator.next() { XCTFail("The AsyncSequence should not output elements") } diff --git a/Tests/Creators/AsyncJustSequenceTests.swift b/Tests/Creators/AsyncJustSequenceTests.swift index 407fb1a..731c634 100644 --- a/Tests/Creators/AsyncJustSequenceTests.swift +++ b/Tests/Creators/AsyncJustSequenceTests.swift @@ -29,7 +29,7 @@ final class AsyncJustSequenceTests: XCTestCase { let justSequence = AsyncJustSequence(1) let task = Task { - wait(for: [hasCancelledExpectation], timeout: 1) + await fulfillment(of: [hasCancelledExpectation], timeout: 1) for await _ in justSequence { XCTFail("The AsyncSequence should not output elements") } diff --git a/Tests/Creators/AsyncLazySequenceTests.swift b/Tests/Creators/AsyncLazySequenceTests.swift index f164fa9..0f4e941 100644 --- a/Tests/Creators/AsyncLazySequenceTests.swift +++ b/Tests/Creators/AsyncLazySequenceTests.swift @@ -36,7 +36,7 @@ final class AsyncLazySequenceTests: XCTestCase { for await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement!, 0) // the AsyncSequence is cancelled having only emitted the first element } diff --git a/Tests/Creators/AsyncThrowingJustSequenceTests.swift b/Tests/Creators/AsyncThrowingJustSequenceTests.swift index aa07089..d5fe8be 100644 --- a/Tests/Creators/AsyncThrowingJustSequenceTests.swift +++ b/Tests/Creators/AsyncThrowingJustSequenceTests.swift @@ -54,7 +54,7 @@ final class AsyncThrowingJustSequenceTests: XCTestCase { let justSequence = AsyncThrowingJustSequence(1) let task = Task { - wait(for: [hasCancelledExpectation], timeout: 1) + await fulfillment(of: [hasCancelledExpectation], timeout: 1) for try await _ in justSequence { XCTFail("The AsyncSequence should not output elements") } diff --git a/Tests/Operators/AsyncHandleEventsSequenceTests.swift b/Tests/Operators/AsyncHandleEventsSequenceTests.swift index 52b90c3..eedd8d4 100644 --- a/Tests/Operators/AsyncHandleEventsSequenceTests.swift +++ b/Tests/Operators/AsyncHandleEventsSequenceTests.swift @@ -53,7 +53,7 @@ final class AsyncHandleEventsSequenceTests: XCTestCase { firstElementHasBeenReceivedExpectation.fulfill() } - wait(for: [taskHasBeenCancelledExpectation], timeout: 1) + await fulfillment(of: [taskHasBeenCancelledExpectation], timeout: 1) } } @@ -95,7 +95,7 @@ final class AsyncHandleEventsSequenceTests: XCTestCase { XCTAssertEqual(error as? MockError, expectedError) } - await waitForExpectations(timeout: 1) + await fulfillment(of: [onFinishHasBeenCalledExpectation], timeout: 1) } func test_iteration_finishes_when_task_is_cancelled() { @@ -112,7 +112,7 @@ final class AsyncHandleEventsSequenceTests: XCTestCase { for try await element in handledSequence { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 0) taskHasFinishedExpectation.fulfill() diff --git a/Tests/Operators/AsyncMulticastSequenceTests.swift b/Tests/Operators/AsyncMulticastSequenceTests.swift index bd60335..ffe9828 100644 --- a/Tests/Operators/AsyncMulticastSequenceTests.swift +++ b/Tests/Operators/AsyncMulticastSequenceTests.swift @@ -39,39 +39,6 @@ private class SpyAsyncSequenceForNumberOfIterators: AsyncSequence { } } -private struct CancellationAwareSequence: AsyncSequence { - typealias Element = Int - typealias AsyncIterator = Iterator - - let onStart: @Sendable () -> Void - let onCancel: @Sendable () -> Void - - func makeAsyncIterator() -> AsyncIterator { - Iterator(onStart: self.onStart, onCancel: self.onCancel) - } - - struct Iterator: AsyncIteratorProtocol { - let onStart: @Sendable () -> Void - let onCancel: @Sendable () -> Void - var hasStarted = false - - mutating func next() async throws -> Int? { - if !hasStarted { - hasStarted = true - onStart() - } - - do { - try await Task.sleep(nanoseconds: 5_000_000_000) - return nil - } catch { - onCancel() - return nil - } - } - } -} - final class AsyncMulticastSequenceTests: XCTestCase { func test_multiple_loops_receive_elements_from_single_baseIterator() { let taskHaveIterators = expectation(description: "All tasks have their iterator") @@ -190,25 +157,4 @@ final class AsyncMulticastSequenceTests: XCTestCase { } } - func test_multicast_cancels_upstream_when_consumer_cancels() async { - let upstreamStartedExpectation = expectation(description: "Upstream started") - let upstreamCancelledExpectation = expectation(description: "Upstream cancelled") - - let base = CancellationAwareSequence( - onStart: { upstreamStartedExpectation.fulfill() }, - onCancel: { upstreamCancelledExpectation.fulfill() } - ) - let stream = AsyncThrowingPassthroughSubject() - let sut = base.multicast(stream).autoconnect() - - let task = Task { - var iterator = sut.makeAsyncIterator() - _ = try? await iterator.next() - } - - await fulfillment(of: [upstreamStartedExpectation], timeout: 1) - task.cancel() - - await fulfillment(of: [upstreamCancelledExpectation], timeout: 1) - } } diff --git a/Tests/Operators/AsyncPrependSequenceTests.swift b/Tests/Operators/AsyncPrependSequenceTests.swift index e520e06..c23c03d 100644 --- a/Tests/Operators/AsyncPrependSequenceTests.swift +++ b/Tests/Operators/AsyncPrependSequenceTests.swift @@ -37,7 +37,7 @@ final class AsyncPrependSequenceTests: XCTestCase { for try await element in prependedSequence { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 0) taskHasFinishedExpectation.fulfill() diff --git a/Tests/Operators/AsyncScanSequenceTests.swift b/Tests/Operators/AsyncScanSequenceTests.swift index 2833086..8eeb269 100644 --- a/Tests/Operators/AsyncScanSequenceTests.swift +++ b/Tests/Operators/AsyncScanSequenceTests.swift @@ -43,7 +43,7 @@ final class AsyncScanSequenceTests: XCTestCase { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, "1") taskHasFinishedExpectation.fulfill() diff --git a/Tests/Operators/AsyncSequence+FlatMapLatestTests.swift b/Tests/Operators/AsyncSequence+FlatMapLatestTests.swift index a968baf..c84a637 100644 --- a/Tests/Operators/AsyncSequence+FlatMapLatestTests.swift +++ b/Tests/Operators/AsyncSequence+FlatMapLatestTests.swift @@ -199,7 +199,7 @@ final class AsyncSequence_FlatMapLatestTests: XCTestCase { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 3) taskHasFinishedExpectation.fulfill() diff --git a/Tests/Operators/AsyncSwitchToLatestSequenceTests.swift b/Tests/Operators/AsyncSwitchToLatestSequenceTests.swift index bda619d..8619cde 100644 --- a/Tests/Operators/AsyncSwitchToLatestSequenceTests.swift +++ b/Tests/Operators/AsyncSwitchToLatestSequenceTests.swift @@ -40,15 +40,15 @@ private struct LongAsyncSequence: AsyncSequence, AsyncIteratorProtocol } mutating func next() async throws -> Element? { - return try await withTaskCancellationHandler { [onCancel] in - onCancel() - } operation: { + return try await withTaskCancellationHandler { try await Task.sleep(nanoseconds: self.interval.nanoseconds) self.currentIndex += 1 if self.currentIndex == self.failAt { throw MockError(code: 0) } return self.elements.next() + } onCancel: { [onCancel] in + onCancel() } } @@ -169,7 +169,7 @@ final class AsyncSwitchToLatestSequenceTests: XCTestCase { for try await element in sut { firstElement = element canCancelExpectation.fulfill() - wait(for: [hasCancelExceptation], timeout: 5) + await fulfillment(of: [hasCancelExceptation], timeout: 5) } XCTAssertEqual(firstElement, 3) taskHasFinishedExpectation.fulfill()