1use crate::*;
4use spin1::Mutex as Spinlock;
5use std::{any::Any, marker::PhantomData};
6
/// Identifies one operation registered with a `Selector`: it is the operation's index in the
/// selector's `selections` list.
type Token = usize;

/// Signal attached to the channel hook of an operation registered with a `Selector`.
///
/// When fired it marks itself as fired, enqueues its token on the selector's shared queue and
/// unparks the thread that registered it.
struct SelectSignal(
    // Thread to unpark when the signal fires (captured with `thread::current()` at registration).
    thread::Thread,
    // Token pushed onto the shared queue when the signal fires.
    Token,
    // Set to `true` when the signal fires; inspected by the receive operation's `deinit`.
    AtomicBool,
    // Queue of fired tokens, shared with the owning `Selector`.
    Arc<Spinlock<VecDeque<Token>>>,
);
16
17impl Signal for SelectSignal {
18 fn fire(&self) -> bool {
19 self.2.store(true, Ordering::SeqCst);
20 self.3.lock().push_back(self.1);
21 self.0.unpark();
22 false
23 }
24
25 fn as_any(&self) -> &(dyn Any + 'static) {
26 self
27 }
28 fn as_ptr(&self) -> *const () {
29 self as *const _ as *const ()
30 }
31}
32
/// One operation (a send or a receive) registered with a `Selector`.
///
/// `T` is the selector's result type; each operation maps its outcome into a `T` with a
/// user-supplied closure.
trait Selection<'a, T> {
    /// Attempt the operation. If it cannot complete immediately, a hook is left registered
    /// with the channel. Returns `Some(mapped result)` only if the operation completed here.
    fn init(&mut self) -> Option<T>;
    /// Check whether the previously registered operation has completed since `init`.
    /// Returns `Some(mapped result)` if so.
    fn poll(&mut self) -> Option<T>;
    /// Unregister any hook that `init` left with the channel.
    fn deinit(&mut self);
}
38
/// An error that may be emitted when waiting on a `Selector` with a timeout or deadline.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SelectError {
    /// The timeout or deadline elapsed before any registered operation completed.
    Timeout,
}

impl fmt::Display for SelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            SelectError::Timeout => "timeout occurred",
        };
        // `pad` honours width/fill/precision exactly like `str`'s own `Display` impl.
        f.pad(description)
    }
}

impl std::error::Error for SelectError {}
55
/// A type used to wait upon multiple blocking channel operations at once.
///
/// Operations are registered with `send`/`recv`, each paired with a closure mapping its outcome
/// to `T`. Waiting returns the mapped result of whichever operation completes first.
pub struct Selector<'a, T: 'a> {
    // Registered operations; an operation's index in this list is its `Token`.
    selections: Vec<Box<dyn Selection<'a, T> + 'a>>,
    // Index at which the next round-robin init/poll pass starts.
    next_poll: usize,
    // Tokens of operations whose signal has fired; shared with every `SelectSignal` created.
    signalled: Arc<Spinlock<VecDeque<Token>>>,
    // Source of the random starting index used for fairness.
    #[cfg(feature = "eventual-fairness")]
    rng: fastrand::Rng,
    // `*const ()` is neither `Send` nor `Sync`, so this marker pins the selector to one thread —
    // presumably because registered signals unpark the thread that registered them.
    phantom: PhantomData<*const ()>,
}
84
85impl<'a, T: 'a> Default for Selector<'a, T> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<'a, T: 'a> fmt::Debug for Selector<'a, T> {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 f.debug_struct("Selector").finish()
94 }
95}
96
97impl<'a, T> Selector<'a, T> {
98 pub fn new() -> Self {
100 Self {
101 selections: Vec::new(),
102 next_poll: 0,
103 signalled: Arc::default(),
104 phantom: PhantomData,
105 #[cfg(feature = "eventual-fairness")]
106 rng: fastrand::Rng::new(),
107 }
108 }
109
110 pub fn send<U, F: FnMut(Result<(), SendError<U>>) -> T + 'a>(
114 mut self,
115 sender: &'a Sender<U>,
116 msg: U,
117 mapper: F,
118 ) -> Self {
119 struct SendSelection<'a, T, F, U> {
120 sender: &'a Sender<U>,
121 msg: Option<U>,
122 token: Token,
123 signalled: Arc<Spinlock<VecDeque<Token>>>,
124 hook: Option<Arc<Hook<U, SelectSignal>>>,
125 mapper: F,
126 phantom: PhantomData<T>,
127 }
128
129 impl<'a, T, F, U> Selection<'a, T> for SendSelection<'a, T, F, U>
130 where
131 F: FnMut(Result<(), SendError<U>>) -> T,
132 {
133 fn init(&mut self) -> Option<T> {
134 let token = self.token;
135 let signalled = self.signalled.clone();
136 let r = self.sender.shared.send(
137 self.msg.take().unwrap(),
138 true,
139 |msg| {
140 Hook::slot(
141 Some(msg),
142 SelectSignal(
143 thread::current(),
144 token,
145 AtomicBool::new(false),
146 signalled,
147 ),
148 )
149 },
150 |h| {
152 self.hook = Some(h);
153 Ok(())
154 },
155 );
156
157 if self.hook.is_none() {
158 Some((self.mapper)(match r {
159 Ok(()) => Ok(()),
160 Err(TrySendTimeoutError::Disconnected(msg)) => Err(SendError(msg)),
161 _ => unreachable!(),
162 }))
163 } else {
164 None
165 }
166 }
167
168 fn poll(&mut self) -> Option<T> {
169 let res = if self.sender.shared.is_disconnected() {
170 if let Some(msg) = self.hook.as_ref()?.try_take() {
172 Err(SendError(msg))
173 } else {
174 Ok(())
175 }
176 } else if self.hook.as_ref().unwrap().is_empty() {
177 Ok(())
179 } else {
180 return None;
181 };
182
183 Some((self.mapper)(res))
184 }
185
186 fn deinit(&mut self) {
187 if let Some(hook) = self.hook.take() {
188 let hook: Arc<Hook<U, dyn Signal>> = hook;
190 wait_lock(&self.sender.shared.chan)
191 .sending
192 .as_mut()
193 .unwrap()
194 .1
195 .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
196 }
197 }
198 }
199
200 let token = self.selections.len();
201 self.selections.push(Box::new(SendSelection {
202 sender,
203 msg: Some(msg),
204 token,
205 signalled: self.signalled.clone(),
206 hook: None,
207 mapper,
208 phantom: Default::default(),
209 }));
210
211 self
212 }
213
214 pub fn recv<U, F: FnMut(Result<U, RecvError>) -> T + 'a>(
218 mut self,
219 receiver: &'a Receiver<U>,
220 mapper: F,
221 ) -> Self {
222 struct RecvSelection<'a, T, F, U> {
223 receiver: &'a Receiver<U>,
224 token: Token,
225 signalled: Arc<Spinlock<VecDeque<Token>>>,
226 hook: Option<Arc<Hook<U, SelectSignal>>>,
227 mapper: F,
228 received: bool,
229 phantom: PhantomData<T>,
230 }
231
232 impl<'a, T, F, U> Selection<'a, T> for RecvSelection<'a, T, F, U>
233 where
234 F: FnMut(Result<U, RecvError>) -> T,
235 {
236 fn init(&mut self) -> Option<T> {
237 let token = self.token;
238 let signalled = self.signalled.clone();
239 let r = self.receiver.shared.recv(
240 true,
241 || {
242 Hook::trigger(SelectSignal(
243 thread::current(),
244 token,
245 AtomicBool::new(false),
246 signalled,
247 ))
248 },
249 |h| {
251 self.hook = Some(h);
252 Err(TryRecvTimeoutError::Timeout)
253 },
254 );
255
256 if self.hook.is_none() {
257 Some((self.mapper)(match r {
258 Ok(msg) => Ok(msg),
259 Err(TryRecvTimeoutError::Disconnected) => Err(RecvError::Disconnected),
260 _ => unreachable!(),
261 }))
262 } else {
263 None
264 }
265 }
266
267 fn poll(&mut self) -> Option<T> {
268 let res = if let Ok(msg) = self.receiver.try_recv() {
269 self.received = true;
270 Ok(msg)
271 } else if self.receiver.shared.is_disconnected() {
272 Err(RecvError::Disconnected)
273 } else {
274 return None;
275 };
276
277 Some((self.mapper)(res))
278 }
279
280 fn deinit(&mut self) {
281 if let Some(hook) = self.hook.take() {
282 let hook: Arc<Hook<U, dyn Signal>> = hook;
284 wait_lock(&self.receiver.shared.chan)
285 .waiting
286 .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
287 if !self.received
289 && hook
290 .signal()
291 .as_any()
292 .downcast_ref::<SelectSignal>()
293 .unwrap()
294 .2
295 .load(Ordering::SeqCst)
296 {
297 wait_lock(&self.receiver.shared.chan).try_wake_receiver_if_pending();
298 }
299 }
300 }
301 }
302
303 let token = self.selections.len();
304 self.selections.push(Box::new(RecvSelection {
305 receiver,
306 token,
307 signalled: self.signalled.clone(),
308 hook: None,
309 mapper,
310 received: false,
311 phantom: Default::default(),
312 }));
313
314 self
315 }
316
317 fn wait_inner(mut self, deadline: Option<Instant>) -> Option<T> {
318 #[cfg(feature = "eventual-fairness")]
319 {
320 self.next_poll = self.rng.usize(0..self.selections.len());
321 }
322
323 let res = 'outer: {
324 for _ in 0..self.selections.len() {
326 if let Some(val) = self.selections[self.next_poll].init() {
327 break 'outer Some(val);
328 }
329 self.next_poll = (self.next_poll + 1) % self.selections.len();
330 }
331
332 if let Some(msg) = self.poll() {
334 break 'outer Some(msg);
335 }
336
337 loop {
338 if let Some(deadline) = deadline {
339 if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
340 thread::park_timeout(dur);
341 }
342 } else {
343 thread::park();
344 }
345
346 if deadline.map(|d| Instant::now() >= d).unwrap_or(false) {
347 break 'outer self.poll();
348 }
349
350 let token = if let Some(token) = self.signalled.lock().pop_front() {
351 token
352 } else {
353 continue;
355 };
356
357 if let Some(msg) = self.selections[token].poll() {
359 break 'outer Some(msg);
360 }
361 }
362 };
363
364 for s in &mut self.selections {
366 s.deinit();
367 }
368
369 res
370 }
371
372 fn poll(&mut self) -> Option<T> {
373 for _ in 0..self.selections.len() {
374 if let Some(val) = self.selections[self.next_poll].poll() {
375 return Some(val);
376 }
377 self.next_poll = (self.next_poll + 1) % self.selections.len();
378 }
379 None
380 }
381
382 pub fn wait(self) -> T {
385 self.wait_inner(None).unwrap()
386 }
387
388 pub fn wait_timeout(self, dur: Duration) -> Result<T, SelectError> {
392 self.wait_inner(Instant::now().checked_add(dur))
393 .ok_or(SelectError::Timeout)
394 }
395
396 pub fn wait_deadline(self, deadline: Instant) -> Result<T, SelectError> {
400 self.wait_inner(Some(deadline)).ok_or(SelectError::Timeout)
401 }
402}