flume/
async.rs

1//! Futures and other types that allow asynchronous interaction with channels.
2
3use crate::*;
4use futures_core::{
5    future::FusedFuture,
6    stream::{FusedStream, Stream},
7};
8use futures_sink::Sink;
9use spin1::Mutex as Spinlock;
10use std::fmt::{Debug, Formatter};
11use std::{
12    any::Any,
13    future::Future,
14    ops::Deref,
15    pin::Pin,
16    task::{Context, Poll, Waker},
17};
18
19struct AsyncSignal {
20    waker: Spinlock<Waker>,
21    woken: AtomicBool,
22    stream: bool,
23}
24
25impl AsyncSignal {
26    fn new(cx: &Context, stream: bool) -> Self {
27        AsyncSignal {
28            waker: Spinlock::new(cx.waker().clone()),
29            woken: AtomicBool::new(false),
30            stream,
31        }
32    }
33}
34
35impl Signal for AsyncSignal {
36    fn fire(&self) -> bool {
37        self.woken.store(true, Ordering::SeqCst);
38        self.waker.lock().wake_by_ref();
39        self.stream
40    }
41
42    fn as_any(&self) -> &(dyn Any + 'static) {
43        self
44    }
45    fn as_ptr(&self) -> *const () {
46        self as *const _ as *const ()
47    }
48}
49
50impl<T> Hook<T, AsyncSignal> {
51    // Update the hook to point to the given Waker.
52    // Returns whether the hook has been previously awakened
53    fn update_waker(&self, cx_waker: &Waker) -> bool {
54        let mut waker = self.1.waker.lock();
55        let woken = self.1.woken.load(Ordering::SeqCst);
56        if !waker.will_wake(cx_waker) {
57            *waker = cx_waker.clone();
58
59            // Avoid the edge case where the waker was woken just before the wakers were
60            // swapped.
61            if woken {
62                cx_waker.wake_by_ref();
63            }
64        }
65        woken
66    }
67}
68
/// Either an owned `T` or a shared borrow of one; dereferences to `T` in both cases.
#[derive(Clone)]
enum OwnedOrRef<'a, T> {
    /// The value is owned by this wrapper.
    Owned(T),
    /// The value is borrowed for `'a`.
    Ref(&'a T),
}

impl<'a, T> Deref for OwnedOrRef<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match *self {
            Self::Owned(ref value) => value,
            Self::Ref(value) => value,
        }
    }
}
85
86impl<T> Sender<T> {
87    /// Asynchronously send a value into the channel, returning an error if all receivers have been
88    /// dropped. If the channel is bounded and is full, the returned future will yield to the async
89    /// runtime.
90    ///
91    /// In the current implementation, the returned future will not yield to the async runtime if the
92    /// channel is unbounded. This may change in later versions.
93    pub fn send_async(&self, item: T) -> SendFut<'_, T> {
94        SendFut {
95            sender: OwnedOrRef::Ref(self),
96            hook: Some(SendState::NotYetSent(item)),
97        }
98    }
99
100    /// Convert this sender into a future that asynchronously sends a single message into the channel,
101    /// returning an error if all receivers have been dropped. If the channel is bounded and is full,
102    /// this future will yield to the async runtime.
103    ///
104    /// In the current implementation, the returned future will not yield to the async runtime if the
105    /// channel is unbounded. This may change in later versions.
106    pub fn into_send_async<'a>(self, item: T) -> SendFut<'a, T> {
107        SendFut {
108            sender: OwnedOrRef::Owned(self),
109            hook: Some(SendState::NotYetSent(item)),
110        }
111    }
112
113    /// Create an asynchronous sink that uses this sender to asynchronously send messages into the
114    /// channel. The sender will continue to be usable after the sink has been dropped.
115    ///
116    /// In the current implementation, the returned sink will not yield to the async runtime if the
117    /// channel is unbounded. This may change in later versions.
118    pub fn sink(&self) -> SendSink<'_, T> {
119        SendSink(SendFut {
120            sender: OwnedOrRef::Ref(self),
121            hook: None,
122        })
123    }
124
125    /// Convert this sender into a sink that allows asynchronously sending messages into the channel.
126    ///
127    /// In the current implementation, the returned sink will not yield to the async runtime if the
128    /// channel is unbounded. This may change in later versions.
129    pub fn into_sink<'a>(self) -> SendSink<'a, T> {
130        SendSink(SendFut {
131            sender: OwnedOrRef::Owned(self),
132            hook: None,
133        })
134    }
135}
136
137enum SendState<T> {
138    NotYetSent(T),
139    QueuedItem(Arc<Hook<T, AsyncSignal>>),
140}
141
142/// A future that sends a value into a channel.
143///
144/// Can be created via [`Sender::send_async`] or [`Sender::into_send_async`].
145#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
146pub struct SendFut<'a, T> {
147    sender: OwnedOrRef<'a, Sender<T>>,
148    // Only none after dropping
149    hook: Option<SendState<T>>,
150}
151
152impl<'a, T> Debug for SendFut<'a, T> {
153    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
154        f.debug_struct("SendFut").finish()
155    }
156}
157
158impl<T> std::marker::Unpin for SendFut<'_, T> {}
159
160impl<'a, T> SendFut<'a, T> {
161    /// Reset the hook, clearing it and removing it from the waiting sender's queue. This is called
162    /// on drop and just before `start_send` in the `Sink` implementation.
163    fn reset_hook(&mut self) {
164        if let Some(SendState::QueuedItem(hook)) = self.hook.take() {
165            let hook: Arc<Hook<T, dyn Signal>> = hook;
166            wait_lock(&self.sender.shared.chan)
167                .sending
168                .as_mut()
169                .unwrap()
170                .1
171                .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
172        }
173    }
174
175    /// See [`Sender::is_disconnected`].
176    pub fn is_disconnected(&self) -> bool {
177        self.sender.is_disconnected()
178    }
179
180    /// See [`Sender::is_empty`].
181    pub fn is_empty(&self) -> bool {
182        self.sender.is_empty()
183    }
184
185    /// See [`Sender::is_full`].
186    pub fn is_full(&self) -> bool {
187        self.sender.is_full()
188    }
189
190    /// See [`Sender::len`].
191    pub fn len(&self) -> usize {
192        self.sender.len()
193    }
194
195    /// See [`Sender::capacity`].
196    pub fn capacity(&self) -> Option<usize> {
197        self.sender.capacity()
198    }
199}
200
201impl<'a, T> Drop for SendFut<'a, T> {
202    fn drop(&mut self) {
203        self.reset_hook()
204    }
205}
206
207impl<'a, T> Future for SendFut<'a, T> {
208    type Output = Result<(), SendError<T>>;
209
210    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
211        if let Some(SendState::QueuedItem(hook)) = self.hook.as_ref() {
212            if hook.is_empty() {
213                Poll::Ready(Ok(()))
214            } else if self.sender.shared.is_disconnected() {
215                let item = hook.try_take();
216                self.hook = None;
217                match item {
218                    Some(item) => Poll::Ready(Err(SendError(item))),
219                    None => Poll::Ready(Ok(())),
220                }
221            } else {
222                hook.update_waker(cx.waker());
223                Poll::Pending
224            }
225        } else if let Some(SendState::NotYetSent(item)) = self.hook.take() {
226            let this = self.get_mut();
227            let (shared, this_hook) = (&this.sender.shared, &mut this.hook);
228
229            shared
230                .send(
231                    // item
232                    item,
233                    // should_block
234                    true,
235                    // make_signal
236                    |msg| Hook::slot(Some(msg), AsyncSignal::new(cx, false)),
237                    // do_block
238                    |hook| {
239                        *this_hook = Some(SendState::QueuedItem(hook));
240                        Poll::Pending
241                    },
242                )
243                .map(|r| {
244                    r.map_err(|err| match err {
245                        TrySendTimeoutError::Disconnected(msg) => SendError(msg),
246                        _ => unreachable!(),
247                    })
248                })
249        } else {
250            // Nothing to do
251            Poll::Ready(Ok(()))
252        }
253    }
254}
255
256impl<'a, T> FusedFuture for SendFut<'a, T> {
257    fn is_terminated(&self) -> bool {
258        self.sender.shared.is_disconnected()
259    }
260}
261
262/// A sink that allows sending values into a channel.
263///
264/// Can be created via [`Sender::sink`] or [`Sender::into_sink`].
265pub struct SendSink<'a, T>(SendFut<'a, T>);
266
267impl<'a, T> SendSink<'a, T> {
268    /// Returns a clone of a sending half of the channel of this sink.
269    pub fn sender(&self) -> &Sender<T> {
270        &self.0.sender
271    }
272
273    /// See [`Sender::is_disconnected`].
274    pub fn is_disconnected(&self) -> bool {
275        self.0.is_disconnected()
276    }
277
278    /// See [`Sender::is_empty`].
279    pub fn is_empty(&self) -> bool {
280        self.0.is_empty()
281    }
282
283    /// See [`Sender::is_full`].
284    pub fn is_full(&self) -> bool {
285        self.0.is_full()
286    }
287
288    /// See [`Sender::len`].
289    pub fn len(&self) -> usize {
290        self.0.len()
291    }
292
293    /// See [`Sender::capacity`].
294    pub fn capacity(&self) -> Option<usize> {
295        self.0.capacity()
296    }
297
298    /// Returns whether the SendSinks are belong to the same channel.
299    pub fn same_channel(&self, other: &Self) -> bool {
300        self.sender().same_channel(other.sender())
301    }
302}
303
304impl<'a, T> Debug for SendSink<'a, T> {
305    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
306        f.debug_struct("SendSink").finish()
307    }
308}
309
310impl<'a, T> Sink<T> for SendSink<'a, T> {
311    type Error = SendError<T>;
312
313    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
314        Pin::new(&mut self.0).poll(cx)
315    }
316
317    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
318        self.0.reset_hook();
319        self.0.hook = Some(SendState::NotYetSent(item));
320
321        Ok(())
322    }
323
324    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
325        Pin::new(&mut self.0).poll(cx) // TODO: A different strategy here?
326    }
327
328    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
329        Pin::new(&mut self.0).poll(cx) // TODO: A different strategy here?
330    }
331}
332
333impl<'a, T> Clone for SendSink<'a, T> {
334    fn clone(&self) -> SendSink<'a, T> {
335        SendSink(SendFut {
336            sender: self.0.sender.clone(),
337            hook: None,
338        })
339    }
340}
341
342impl<T> Receiver<T> {
343    /// Asynchronously receive a value from the channel, returning an error if all senders have been
344    /// dropped. If the channel is empty, the returned future will yield to the async runtime.
345    pub fn recv_async(&self) -> RecvFut<'_, T> {
346        RecvFut::new(OwnedOrRef::Ref(self))
347    }
348
349    /// Convert this receiver into a future that asynchronously receives a single message from the
350    /// channel, returning an error if all senders have been dropped. If the channel is empty, this
351    /// future will yield to the async runtime.
352    pub fn into_recv_async<'a>(self) -> RecvFut<'a, T> {
353        RecvFut::new(OwnedOrRef::Owned(self))
354    }
355
356    /// Create an asynchronous stream that uses this receiver to asynchronously receive messages
357    /// from the channel. The receiver will continue to be usable after the stream has been dropped.
358    pub fn stream(&self) -> RecvStream<'_, T> {
359        RecvStream(RecvFut::new(OwnedOrRef::Ref(self)))
360    }
361
362    /// Convert this receiver into a stream that allows asynchronously receiving messages from the channel.
363    pub fn into_stream<'a>(self) -> RecvStream<'a, T> {
364        RecvStream(RecvFut::new(OwnedOrRef::Owned(self)))
365    }
366}
367
/// A future which allows asynchronously receiving a message.
///
/// Can be created via [`Receiver::recv_async`] or [`Receiver::into_recv_async`].
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct RecvFut<'a, T> {
    // The receiver this future reads from, either owned or borrowed.
    receiver: OwnedOrRef<'a, Receiver<T>>,
    // Hook handed back by the channel when a poll could not complete immediately (see the
    // `do_block` closure in `poll_inner`); `None` before that and after `reset_hook`.
    hook: Option<Arc<Hook<T, AsyncSignal>>>,
}

impl<'a, T> RecvFut<'a, T> {
    /// Create a future that has not registered any hook with the channel yet.
    fn new(receiver: OwnedOrRef<'a, Receiver<T>>) -> Self {
        Self {
            receiver,
            hook: None,
        }
    }

    /// Reset the hook, clearing it and removing it from the waiting receivers queue and waking
    /// another receiver if this receiver has been woken, so as not to cause any missed wakeups.
    /// This is called on drop and after a new item is received in `Stream::poll_next`.
    fn reset_hook(&mut self) {
        if let Some(hook) = self.hook.take() {
            // Coerce to the trait-object form stored in the channel's `waiting` queue.
            let hook: Arc<Hook<T, dyn Signal>> = hook;
            let mut chan = wait_lock(&self.receiver.shared.chan);
            // We'd like to use `Arc::ptr_eq` here but it doesn't seem to work consistently with wide pointers?
            chan.waiting
                .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
            // The downcast cannot fail: `self.hook` only ever holds `AsyncSignal` hooks.
            if hook
                .signal()
                .as_any()
                .downcast_ref::<AsyncSignal>()
                .unwrap()
                .woken
                .load(Ordering::SeqCst)
            {
                // If this signal has been fired, but we're being dropped (and so not listening to it),
                // pass the signal on to another receiver
                chan.try_wake_receiver_if_pending();
            }
        }
    }

    /// Shared polling logic for [`RecvFut`] (`stream == false`) and [`RecvStream`]
    /// (`stream == true`); `stream` is only passed to the `AsyncSignal` created on the first
    /// blocking poll.
    ///
    /// With no hook registered yet, this attempts a receive via `shared.recv`, storing the hook
    /// and returning `Pending` if the channel asks it to block. With a hook already registered,
    /// it retries a receive, refreshes the waker, re-queues the hook if it had been fired, and
    /// re-checks for disconnection before returning `Pending`.
    fn poll_inner(
        self: Pin<&mut Self>,
        cx: &mut Context,
        stream: bool,
    ) -> Poll<Result<T, RecvError>> {
        if self.hook.is_some() {
            // NOTE(review): `recv_sync(None)` is assumed to be a non-blocking attempt — confirm.
            match self.receiver.shared.recv_sync(None) {
                Ok(msg) => return Poll::Ready(Ok(msg)),
                Err(TryRecvTimeoutError::Disconnected) => {
                    return Poll::Ready(Err(RecvError::Disconnected))
                }
                // Any other outcome (nothing received yet) falls through to re-registering.
                _ => (),
            }

            let hook = self.hook.as_ref().map(Arc::clone).unwrap();
            // `update_waker` returns whether the hook's signal had already fired.
            if hook.update_waker(cx.waker()) {
                // If the previous hook was awakened, we need to insert it back to the
                // queue, otherwise, it remains valid.
                wait_lock(&self.receiver.shared.chan)
                    .waiting
                    .push_back(hook);
            }
            // To avoid a missed wakeup, re-check disconnect status here because the channel might have
            // gotten shut down before we had a chance to push our hook
            if self.receiver.shared.is_disconnected() {
                // And now, to avoid a race condition between the first recv attempt and the disconnect check we
                // just performed, attempt to recv again just in case we missed something.
                Poll::Ready(
                    self.receiver
                        .shared
                        .recv_sync(None)
                        .map(Ok)
                        .unwrap_or(Err(RecvError::Disconnected)),
                )
            } else {
                Poll::Pending
            }
        } else {
            // First poll (or first poll after `reset_hook`): no hook registered yet.
            let mut_self = self.get_mut();
            let (shared, this_hook) = (&mut_self.receiver.shared, &mut mut_self.hook);

            shared
                .recv(
                    // should_block
                    true,
                    // make_signal
                    || Hook::trigger(AsyncSignal::new(cx, stream)),
                    // do_block: remember the hook so later polls take the branch above
                    |hook| {
                        *this_hook = Some(hook);
                        Poll::Pending
                    },
                )
                .map(|r| {
                    r.map_err(|err| match err {
                        TryRecvTimeoutError::Disconnected => RecvError::Disconnected,
                        // Treated as unreachable because this receive is allowed to block.
                        _ => unreachable!(),
                    })
                })
        }
    }

    /// See [`Receiver::is_disconnected`].
    pub fn is_disconnected(&self) -> bool {
        self.receiver.is_disconnected()
    }

    /// See [`Receiver::is_empty`].
    pub fn is_empty(&self) -> bool {
        self.receiver.is_empty()
    }

    /// See [`Receiver::is_full`].
    pub fn is_full(&self) -> bool {
        self.receiver.is_full()
    }

    /// See [`Receiver::len`].
    pub fn len(&self) -> usize {
        self.receiver.len()
    }

    /// See [`Receiver::capacity`].
    pub fn capacity(&self) -> Option<usize> {
        self.receiver.capacity()
    }
}
497
498impl<'a, T> Debug for RecvFut<'a, T> {
499    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
500        f.debug_struct("RecvFut").finish()
501    }
502}
503
504impl<'a, T> Drop for RecvFut<'a, T> {
505    fn drop(&mut self) {
506        self.reset_hook();
507    }
508}
509
510impl<'a, T> Future for RecvFut<'a, T> {
511    type Output = Result<T, RecvError>;
512
513    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
514        self.poll_inner(cx, false) // stream = false
515    }
516}
517
518impl<'a, T> FusedFuture for RecvFut<'a, T> {
519    fn is_terminated(&self) -> bool {
520        self.receiver.shared.is_disconnected() && self.receiver.shared.is_empty()
521    }
522}
523
524/// A stream which allows asynchronously receiving messages.
525///
526/// Can be created via [`Receiver::stream`] or [`Receiver::into_stream`].
527pub struct RecvStream<'a, T>(RecvFut<'a, T>);
528
529impl<'a, T> RecvStream<'a, T> {
530    /// See [`Receiver::is_disconnected`].
531    pub fn is_disconnected(&self) -> bool {
532        self.0.is_disconnected()
533    }
534
535    /// See [`Receiver::is_empty`].
536    pub fn is_empty(&self) -> bool {
537        self.0.is_empty()
538    }
539
540    /// See [`Receiver::is_full`].
541    pub fn is_full(&self) -> bool {
542        self.0.is_full()
543    }
544
545    /// See [`Receiver::len`].
546    pub fn len(&self) -> usize {
547        self.0.len()
548    }
549
550    /// See [`Receiver::capacity`].
551    pub fn capacity(&self) -> Option<usize> {
552        self.0.capacity()
553    }
554
555    /// Returns whether the SendSinks are belong to the same channel.
556    pub fn same_channel(&self, other: &Self) -> bool {
557        self.0.receiver.same_channel(&*other.0.receiver)
558    }
559}
560
561impl<'a, T> Debug for RecvStream<'a, T> {
562    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
563        f.debug_struct("RecvStream").finish()
564    }
565}
566
567impl<'a, T> Clone for RecvStream<'a, T> {
568    fn clone(&self) -> RecvStream<'a, T> {
569        RecvStream(RecvFut::new(self.0.receiver.clone()))
570    }
571}
572
573impl<'a, T> Stream for RecvStream<'a, T> {
574    type Item = T;
575
576    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
577        match Pin::new(&mut self.0).poll_inner(cx, true) {
578            // stream = true
579            Poll::Pending => Poll::Pending,
580            Poll::Ready(item) => {
581                self.0.reset_hook();
582                Poll::Ready(item.ok())
583            }
584        }
585    }
586}
587
588impl<'a, T> FusedStream for RecvStream<'a, T> {
589    fn is_terminated(&self) -> bool {
590        self.0.is_terminated()
591    }
592}