11#![ allow( clippy:: tabs_in_doc_comments) ]
2- #![ cfg_attr( feature = "unstable-thread-local" , feature( thread_local) ) ]
3- #![ cfg_attr( all( not( test) , feature = "unstable-thread-local" ) , no_std) ]
2+ #![ cfg_attr( not( test) , no_std) ]
43
4+ extern crate alloc;
55extern crate core;
66
7+ use alloc:: sync:: { Arc , Weak } ;
78use core:: {
89 cell:: Cell ,
910 future:: Future ,
1011 marker:: PhantomData ,
1112 pin:: Pin ,
12- ptr ,
13+ sync :: atomic :: { AtomicBool , Ordering } ,
1314 task:: { Context , Poll }
1415} ;
1516
@@ -19,91 +20,96 @@ use futures_core::stream::{FusedStream, Stream};
1920mod tests;
2021mod r#try;
2122
22- # [ cfg ( not ( feature = "unstable-thread-local" ) ) ]
23- thread_local ! {
24- static STORE : Cell <* mut ( ) > = const { Cell :: new ( ptr :: null_mut ( ) ) } ;
23+ pub ( crate ) struct SharedStore < T > {
24+ entered : AtomicBool ,
25+ cell : Cell < Option < T > >
2526}
26- #[ cfg( feature = "unstable-thread-local" ) ]
27- #[ thread_local]
28- static STORE : Cell < * mut ( ) > = Cell :: new ( ptr:: null_mut ( ) ) ;
2927
30- pub ( crate ) fn r#yield < T > ( value : T ) -> YieldFut < T > {
31- YieldFut { value : Some ( value) }
28+ impl < T > Default for SharedStore < T > {
29+ fn default ( ) -> Self {
30+ Self {
31+ entered : AtomicBool :: new ( false ) ,
32+ cell : Cell :: new ( None )
33+ }
34+ }
35+ }
36+
37+ impl < T > SharedStore < T > {
38+ pub fn has_value ( & self ) -> bool {
39+ unsafe { & * self . cell . as_ptr ( ) } . is_some ( )
40+ }
41+ }
42+
43+ unsafe impl < T > Sync for SharedStore < T > { }
44+
45+ pub struct Yielder < T > {
46+ pub ( crate ) store : Weak < SharedStore < T > >
47+ }
48+
49+ impl < T > Yielder < T > {
50+ pub fn r#yield ( & self , value : T ) -> YieldFut < ' _ , T > {
51+ #[ cold]
52+ fn invalid_usage ( ) -> ! {
53+ panic ! ( "attempted to use async_stream_lite yielder outside of stream context or across threads" )
54+ }
55+
56+ let Some ( store) = self . store . upgrade ( ) else {
57+ invalid_usage ( ) ;
58+ } ;
59+ if !store. entered . load ( Ordering :: Relaxed ) {
60+ invalid_usage ( ) ;
61+ }
62+
63+ store. cell . replace ( Some ( value) ) ;
64+
65+ YieldFut { store, _p : PhantomData }
66+ }
3267}
3368
3469/// Future returned by an [`AsyncStream`]'s yield function.
3570///
3671/// This future must be `.await`ed inside the generator in order for the item to be yielded by the stream.
3772#[ must_use = "stream will not yield this item unless the future returned by yield is awaited" ]
38- pub struct YieldFut < T > {
39- value : Option < T >
73+ pub struct YieldFut < ' y , T > {
74+ store : Arc < SharedStore < T > > ,
75+ _p : PhantomData < & ' y ( ) >
4076}
4177
42- impl < T > Unpin for YieldFut < T > { }
78+ impl < T > Unpin for YieldFut < ' _ , T > { }
4379
44- impl < T > Future for YieldFut < T > {
80+ impl < T > Future for YieldFut < ' _ , T > {
4581 type Output = ( ) ;
4682
47- fn poll ( mut self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
48- if self . value . is_none ( ) {
83+ fn poll ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
84+ if ! self . store . has_value ( ) {
4985 return Poll :: Ready ( ( ) ) ;
5086 }
5187
52- fn op < T > ( cell : & Cell < * mut ( ) > , value : & mut Option < T > ) {
53- let ptr = cell. get ( ) . cast :: < Option < T > > ( ) ;
54- let option_ref = unsafe { ptr. as_mut ( ) } . expect ( "attempted to use async_stream yielder outside of stream context or across threads" ) ;
55- if option_ref. is_none ( ) {
56- * option_ref = value. take ( ) ;
57- }
58- }
59-
60- #[ cfg( not( feature = "unstable-thread-local" ) ) ]
61- return STORE . with ( |cell| {
62- op ( cell, & mut self . value ) ;
63- Poll :: Pending
64- } ) ;
65- #[ cfg( feature = "unstable-thread-local" ) ]
66- {
67- op ( & STORE , & mut self . value ) ;
68- Poll :: Pending
69- }
88+ Poll :: Pending
7089 }
7190}
7291
73- struct Enter < ' a , T > {
74- _p : PhantomData < & ' a T > ,
75- prev : * mut ( )
92+ struct Enter < ' s , T > {
93+ store : & ' s SharedStore < T >
7694}
7795
78- fn enter < T > ( dst : & mut Option < T > ) -> Enter < ' _ , T > {
79- fn op < T > ( cell : & Cell < * mut ( ) > , dst : & mut Option < T > ) -> * mut ( ) {
80- let prev = cell. get ( ) ;
81- cell. set ( ( dst as * mut Option < T > ) . cast :: < ( ) > ( ) ) ;
82- prev
83- }
84- #[ cfg( not( feature = "unstable-thread-local" ) ) ]
85- let prev = STORE . with ( |cell| op ( cell, dst) ) ;
86- #[ cfg( feature = "unstable-thread-local" ) ]
87- let prev = op ( & STORE , dst) ;
88- Enter { _p : PhantomData , prev }
96+ fn enter < T > ( store : & SharedStore < T > ) -> Enter < ' _ , T > {
97+ store. entered . store ( true , Ordering :: Relaxed ) ;
98+ Enter { store }
8999}
90100
91101impl < T > Drop for Enter < ' _ , T > {
92102 fn drop ( & mut self ) {
93- #[ cfg( not( feature = "unstable-thread-local" ) ) ]
94- STORE . with ( |cell| cell. set ( self . prev ) ) ;
95- #[ cfg( feature = "unstable-thread-local" ) ]
96- STORE . set ( self . prev ) ;
103+ self . store . entered . store ( false , Ordering :: Relaxed ) ;
97104 }
98105}
99106
100107pin_project_lite:: pin_project! {
101108 /// A [`Stream`] created from an asynchronous generator-like function.
102109 ///
103110 /// To create an [`AsyncStream`], use the [`async_stream`] function.
104- #[ derive( Debug ) ]
105111 pub struct AsyncStream <T , U > {
106- _p : PhantomData < T >,
112+ store : Arc < SharedStore < T > >,
107113 done: bool ,
108114 #[ pin]
109115 generator: U
@@ -131,16 +137,15 @@ where
131137 return Poll :: Ready ( None ) ;
132138 }
133139
134- let mut dst = None ;
135140 let res = {
136- let _enter = enter ( & mut dst ) ;
141+ let _enter = enter ( & me . store ) ;
137142 me. generator . poll ( cx)
138143 } ;
139144
140145 * me. done = res. is_ready ( ) ;
141146
142- if dst . is_some ( ) {
143- return Poll :: Ready ( dst . take ( ) ) ;
147+ if me . store . has_value ( ) {
148+ return Poll :: Ready ( me . store . cell . take ( ) ) ;
144149 }
145150
146151 if * me. done { Poll :: Ready ( None ) } else { Poll :: Pending }
@@ -153,16 +158,16 @@ where
153158
154159/// Create an asynchronous [`Stream`] from an asynchronous generator function.
155160///
156- /// The provided function will be given a "yielder" function , which, when called, causes the stream to yield an item:
161+ /// The provided function will be given a [`Yielder`] , which, when called, causes the stream to yield an item:
157162/// ```
158163/// use async_stream_lite::async_stream;
159164/// use futures::{pin_mut, stream::StreamExt};
160165///
161166/// #[tokio::main]
162167/// async fn main() {
163- /// let stream = async_stream(|r#yield | async move {
168+ /// let stream = async_stream(|yielder | async move {
164169/// for i in 0..3 {
165- /// r#yield(i).await;
170+ /// yielder. r#yield(i).await;
166171/// }
167172/// });
168173/// pin_mut!(stream);
@@ -181,9 +186,9 @@ where
181186/// };
182187///
183188/// fn zero_to_three() -> impl Stream<Item = u32> {
184- /// async_stream(|r#yield | async move {
189+ /// async_stream(|yielder | async move {
185190/// for i in 0..3 {
186- /// r#yield(i).await;
191+ /// yielder. r#yield(i).await;
187192/// }
188193/// })
189194/// }
@@ -207,9 +212,9 @@ where
207212/// };
208213///
209214/// fn zero_to_three() -> BoxStream<'static, u32> {
210- /// Box::pin(async_stream(|r#yield | async move {
215+ /// Box::pin(async_stream(|yielder | async move {
211216/// for i in 0..3 {
212- /// r#yield(i).await;
217+ /// yielder. r#yield(i).await;
213218/// }
214219/// }))
215220/// }
@@ -232,18 +237,18 @@ where
232237/// };
233238///
234239/// fn zero_to_three() -> impl Stream<Item = u32> {
235- /// async_stream(|r#yield | async move {
240+ /// async_stream(|yielder | async move {
236241/// for i in 0..3 {
237- /// r#yield(i).await;
242+ /// yielder. r#yield(i).await;
238243/// }
239244/// })
240245/// }
241246///
242247/// fn double<S: Stream<Item = u32>>(input: S) -> impl Stream<Item = u32> {
243- /// async_stream(|r#yield | async move {
248+ /// async_stream(|yielder | async move {
244249/// pin_mut!(input);
245250/// while let Some(value) = input.next().await {
246- /// r#yield(value * 2).await;
251+ /// yielder. r#yield(value * 2).await;
247252/// }
248253/// })
249254/// }
@@ -261,15 +266,12 @@ where
261266/// See also [`try_async_stream`], a variant of [`async_stream`] which supports try notation (`?`).
262267pub fn async_stream < T , F , U > ( generator : F ) -> AsyncStream < T , U >
263268where
264- F : FnOnce ( fn ( value : T ) -> YieldFut < T > ) -> U ,
269+ F : FnOnce ( Yielder < T > ) -> U ,
265270 U : Future < Output = ( ) >
266271{
267- let generator = generator ( r#yield :: < T > ) ;
268- AsyncStream {
269- _p : PhantomData ,
270- done : false ,
271- generator
272- }
272+ let store = Arc :: new ( SharedStore :: default ( ) ) ;
273+ let generator = generator ( Yielder { store : Arc :: downgrade ( & store) } ) ;
274+ AsyncStream { store, done : false , generator }
273275}
274276
275277pub use self :: r#try:: { TryAsyncStream , try_async_stream} ;
0 commit comments