flume/
select.rs

1//! Types that permit waiting upon multiple blocking operations using the [`Selector`] interface.
2
3use crate::*;
4use spin1::Mutex as Spinlock;
5use std::{any::Any, marker::PhantomData};
6
// A unique token corresponding to an event (one registered operation) in a selector.
//
// The token is the operation's index in `Selector::selections`: `send`/`recv` assign
// `self.selections.len()` when pushing a selection, and `wait_inner` looks the selection back up
// with `self.selections[token]` after popping the token from the signalled queue.
type Token = usize;
9
10struct SelectSignal(
11    thread::Thread,
12    Token,
13    AtomicBool,
14    Arc<Spinlock<VecDeque<Token>>>,
15);
16
17impl Signal for SelectSignal {
18    fn fire(&self) -> bool {
19        self.2.store(true, Ordering::SeqCst);
20        self.3.lock().push_back(self.1);
21        self.0.unpark();
22        false
23    }
24
25    fn as_any(&self) -> &(dyn Any + 'static) {
26        self
27    }
28    fn as_ptr(&self) -> *const () {
29        self as *const _ as *const ()
30    }
31}
32
// One operation registered with a `Selector`, driven by `Selector::wait_inner` in three phases.
trait Selection<'a, T> {
    // Attempts the operation straight away. Returns the handler's result if it completed without
    // waiting; otherwise registers a hook with the channel and returns `None`.
    fn init(&mut self) -> Option<T>;
    // Re-checks the operation (speculatively, after a wakeup, or at the deadline). Returns the
    // handler's result once the operation has completed or the channel disconnected, else `None`.
    fn poll(&mut self) -> Option<T>;
    // Removes any hook registered by `init` from the channel; called on every selection once the
    // wait is over.
    fn deinit(&mut self);
}
38
/// The error returned by [`Selector::wait_timeout`] and [`Selector::wait_deadline`] when no
/// registered operation completes in time.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SelectError {
    /// No operation completed before the timeout or deadline elapsed.
    Timeout,
}

impl fmt::Display for SelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::Timeout => "timeout occurred",
        };
        // `pad` honours width/fill/alignment flags exactly like `str`'s own `Display` impl.
        f.pad(description)
    }
}

impl std::error::Error for SelectError {}
55
/// A type used to wait upon multiple blocking operations at once.
///
/// A [`Selector`] implements [`select`](https://en.wikipedia.org/wiki/Select_(Unix))-like behaviour,
/// allowing a thread to wait upon the result of more than one operation at once.
///
/// Operations are registered with [`Selector::send`] and [`Selector::recv`], each with a handler.
/// Waiting consumes the selector and runs exactly one handler: that of the first operation found to
/// complete. Its return value is what the `wait*` method returns.
///
/// # Examples
/// ```
/// let (tx0, rx0) = flume::unbounded();
/// let (tx1, rx1) = flume::unbounded();
///
/// std::thread::spawn(move || {
///     tx0.send(true).unwrap();
///     tx1.send(42).unwrap();
/// });
///
/// flume::Selector::new()
///     .recv(&rx0, |b| println!("Received {:?}", b))
///     .recv(&rx1, |n| println!("Received {:?}", n))
///     .wait();
/// ```
pub struct Selector<'a, T: 'a> {
    // Registered operations in registration order; a selection's index is its `Token`.
    selections: Vec<Box<dyn Selection<'a, T> + 'a>>,
    // Index of the next selection to init/poll; advanced round-robin.
    next_poll: usize,
    // Tokens of selections whose signal has fired; a clone is handed to every `SelectSignal`.
    signalled: Arc<Spinlock<VecDeque<Token>>>,
    // Picks a random starting selection when `eventual-fairness` is enabled.
    #[cfg(feature = "eventual-fairness")]
    rng: fastrand::Rng,
    // A `*const ()` marker makes `Selector` neither `Send` nor `Sync`.
    phantom: PhantomData<*const ()>,
}
84
85impl<'a, T: 'a> Default for Selector<'a, T> {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl<'a, T: 'a> fmt::Debug for Selector<'a, T> {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        f.debug_struct("Selector").finish()
94    }
95}
96
97impl<'a, T> Selector<'a, T> {
98    /// Create a new selector.
99    pub fn new() -> Self {
100        Self {
101            selections: Vec::new(),
102            next_poll: 0,
103            signalled: Arc::default(),
104            phantom: PhantomData,
105            #[cfg(feature = "eventual-fairness")]
106            rng: fastrand::Rng::new(),
107        }
108    }
109
110    /// Add a send operation to the selector that sends the provided value.
111    ///
112    /// Once added, the selector can be used to run the provided handler function on completion of this operation.
113    pub fn send<U, F: FnMut(Result<(), SendError<U>>) -> T + 'a>(
114        mut self,
115        sender: &'a Sender<U>,
116        msg: U,
117        mapper: F,
118    ) -> Self {
119        struct SendSelection<'a, T, F, U> {
120            sender: &'a Sender<U>,
121            msg: Option<U>,
122            token: Token,
123            signalled: Arc<Spinlock<VecDeque<Token>>>,
124            hook: Option<Arc<Hook<U, SelectSignal>>>,
125            mapper: F,
126            phantom: PhantomData<T>,
127        }
128
129        impl<'a, T, F, U> Selection<'a, T> for SendSelection<'a, T, F, U>
130        where
131            F: FnMut(Result<(), SendError<U>>) -> T,
132        {
133            fn init(&mut self) -> Option<T> {
134                let token = self.token;
135                let signalled = self.signalled.clone();
136                let r = self.sender.shared.send(
137                    self.msg.take().unwrap(),
138                    true,
139                    |msg| {
140                        Hook::slot(
141                            Some(msg),
142                            SelectSignal(
143                                thread::current(),
144                                token,
145                                AtomicBool::new(false),
146                                signalled,
147                            ),
148                        )
149                    },
150                    // Always runs
151                    |h| {
152                        self.hook = Some(h);
153                        Ok(())
154                    },
155                );
156
157                if self.hook.is_none() {
158                    Some((self.mapper)(match r {
159                        Ok(()) => Ok(()),
160                        Err(TrySendTimeoutError::Disconnected(msg)) => Err(SendError(msg)),
161                        _ => unreachable!(),
162                    }))
163                } else {
164                    None
165                }
166            }
167
168            fn poll(&mut self) -> Option<T> {
169                let res = if self.sender.shared.is_disconnected() {
170                    // Check the hook one last time
171                    if let Some(msg) = self.hook.as_ref()?.try_take() {
172                        Err(SendError(msg))
173                    } else {
174                        Ok(())
175                    }
176                } else if self.hook.as_ref().unwrap().is_empty() {
177                    // The message was sent
178                    Ok(())
179                } else {
180                    return None;
181                };
182
183                Some((self.mapper)(res))
184            }
185
186            fn deinit(&mut self) {
187                if let Some(hook) = self.hook.take() {
188                    // Remove hook
189                    let hook: Arc<Hook<U, dyn Signal>> = hook;
190                    wait_lock(&self.sender.shared.chan)
191                        .sending
192                        .as_mut()
193                        .unwrap()
194                        .1
195                        .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
196                }
197            }
198        }
199
200        let token = self.selections.len();
201        self.selections.push(Box::new(SendSelection {
202            sender,
203            msg: Some(msg),
204            token,
205            signalled: self.signalled.clone(),
206            hook: None,
207            mapper,
208            phantom: Default::default(),
209        }));
210
211        self
212    }
213
214    /// Add a receive operation to the selector.
215    ///
216    /// Once added, the selector can be used to run the provided handler function on completion of this operation.
217    pub fn recv<U, F: FnMut(Result<U, RecvError>) -> T + 'a>(
218        mut self,
219        receiver: &'a Receiver<U>,
220        mapper: F,
221    ) -> Self {
222        struct RecvSelection<'a, T, F, U> {
223            receiver: &'a Receiver<U>,
224            token: Token,
225            signalled: Arc<Spinlock<VecDeque<Token>>>,
226            hook: Option<Arc<Hook<U, SelectSignal>>>,
227            mapper: F,
228            received: bool,
229            phantom: PhantomData<T>,
230        }
231
232        impl<'a, T, F, U> Selection<'a, T> for RecvSelection<'a, T, F, U>
233        where
234            F: FnMut(Result<U, RecvError>) -> T,
235        {
236            fn init(&mut self) -> Option<T> {
237                let token = self.token;
238                let signalled = self.signalled.clone();
239                let r = self.receiver.shared.recv(
240                    true,
241                    || {
242                        Hook::trigger(SelectSignal(
243                            thread::current(),
244                            token,
245                            AtomicBool::new(false),
246                            signalled,
247                        ))
248                    },
249                    // Always runs
250                    |h| {
251                        self.hook = Some(h);
252                        Err(TryRecvTimeoutError::Timeout)
253                    },
254                );
255
256                if self.hook.is_none() {
257                    Some((self.mapper)(match r {
258                        Ok(msg) => Ok(msg),
259                        Err(TryRecvTimeoutError::Disconnected) => Err(RecvError::Disconnected),
260                        _ => unreachable!(),
261                    }))
262                } else {
263                    None
264                }
265            }
266
267            fn poll(&mut self) -> Option<T> {
268                let res = if let Ok(msg) = self.receiver.try_recv() {
269                    self.received = true;
270                    Ok(msg)
271                } else if self.receiver.shared.is_disconnected() {
272                    Err(RecvError::Disconnected)
273                } else {
274                    return None;
275                };
276
277                Some((self.mapper)(res))
278            }
279
280            fn deinit(&mut self) {
281                if let Some(hook) = self.hook.take() {
282                    // Remove hook
283                    let hook: Arc<Hook<U, dyn Signal>> = hook;
284                    wait_lock(&self.receiver.shared.chan)
285                        .waiting
286                        .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
287                    // If we were woken, but never polled, wake up another
288                    if !self.received
289                        && hook
290                            .signal()
291                            .as_any()
292                            .downcast_ref::<SelectSignal>()
293                            .unwrap()
294                            .2
295                            .load(Ordering::SeqCst)
296                    {
297                        wait_lock(&self.receiver.shared.chan).try_wake_receiver_if_pending();
298                    }
299                }
300            }
301        }
302
303        let token = self.selections.len();
304        self.selections.push(Box::new(RecvSelection {
305            receiver,
306            token,
307            signalled: self.signalled.clone(),
308            hook: None,
309            mapper,
310            received: false,
311            phantom: Default::default(),
312        }));
313
314        self
315    }
316
317    fn wait_inner(mut self, deadline: Option<Instant>) -> Option<T> {
318        #[cfg(feature = "eventual-fairness")]
319        {
320            self.next_poll = self.rng.usize(0..self.selections.len());
321        }
322
323        let res = 'outer: {
324            // Init signals
325            for _ in 0..self.selections.len() {
326                if let Some(val) = self.selections[self.next_poll].init() {
327                    break 'outer Some(val);
328                }
329                self.next_poll = (self.next_poll + 1) % self.selections.len();
330            }
331
332            // Speculatively poll
333            if let Some(msg) = self.poll() {
334                break 'outer Some(msg);
335            }
336
337            loop {
338                if let Some(deadline) = deadline {
339                    if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
340                        thread::park_timeout(dur);
341                    }
342                } else {
343                    thread::park();
344                }
345
346                if deadline.map(|d| Instant::now() >= d).unwrap_or(false) {
347                    break 'outer self.poll();
348                }
349
350                let token = if let Some(token) = self.signalled.lock().pop_front() {
351                    token
352                } else {
353                    // Spurious wakeup, park again
354                    continue;
355                };
356
357                // Attempt to receive a message
358                if let Some(msg) = self.selections[token].poll() {
359                    break 'outer Some(msg);
360                }
361            }
362        };
363
364        // Deinit signals
365        for s in &mut self.selections {
366            s.deinit();
367        }
368
369        res
370    }
371
372    fn poll(&mut self) -> Option<T> {
373        for _ in 0..self.selections.len() {
374            if let Some(val) = self.selections[self.next_poll].poll() {
375                return Some(val);
376            }
377            self.next_poll = (self.next_poll + 1) % self.selections.len();
378        }
379        None
380    }
381
382    /// Wait until one of the events associated with this [`Selector`] has completed. If the `eventual-fairness`
383    /// feature flag is enabled, this method is fair and will handle a random event of those that are ready.
384    pub fn wait(self) -> T {
385        self.wait_inner(None).unwrap()
386    }
387
388    /// Wait until one of the events associated with this [`Selector`] has completed or the timeout has expired. If the
389    /// `eventual-fairness` feature flag is enabled, this method is fair and will handle a random event of those that
390    /// are ready.
391    pub fn wait_timeout(self, dur: Duration) -> Result<T, SelectError> {
392        self.wait_inner(Instant::now().checked_add(dur))
393            .ok_or(SelectError::Timeout)
394    }
395
396    /// Wait until one of the events associated with this [`Selector`] has completed or the deadline has been reached.
397    /// If the `eventual-fairness` feature flag is enabled, this method is fair and will handle a random event of those
398    /// that are ready.
399    pub fn wait_deadline(self, deadline: Instant) -> Result<T, SelectError> {
400        self.wait_inner(Some(deadline)).ok_or(SelectError::Timeout)
401    }
402}