rtic_sync/
channel.rs

1//! An async aware MPSC channel that can be used on no-alloc systems.
2
3use crate::unsafecell::UnsafeCell;
4use core::{
5    future::poll_fn,
6    mem::MaybeUninit,
7    pin::Pin,
8    ptr,
9    sync::atomic::{fence, Ordering},
10    task::{Poll, Waker},
11};
12#[doc(hidden)]
13pub use critical_section;
14use heapless::Deque;
15use rtic_common::{
16    dropper::OnDrop, wait_queue::DoublyLinkedList, wait_queue::Link,
17    waker_registration::CriticalSectionWakerRegistration as WakerRegistration,
18};
19
20#[cfg(feature = "defmt-03")]
21use crate::defmt;
22
// Each wait-queue entry pairs the waiting sender's waker with a pointer to the slot
// through which `return_free_slot` hands it a free slot index.
type WaitQueueData = (Waker, SlotPtr);
type WaitQueue = DoublyLinkedList<WaitQueueData>;

/// An MPSC channel for use in no-alloc systems. `N` sets the size of the queue.
///
/// This channel uses critical sections, however they are extremely small and all `memcpy`
/// operations of `T` are done without critical sections.
pub struct Channel<T, const N: usize> {
    // Here are all indexes that are not used in `slots` and ready to be allocated.
    freeq: UnsafeCell<Deque<u8, N>>,
    // Here are the indexes of slots that hold a value ready to be dequeued by the receiver.
    readyq: UnsafeCell<Deque<u8, N>>,
    // Waker for the receiver.
    receiver_waker: WakerRegistration,
    // Storage for N `T`s, so we don't memcpy around a lot of `T`s.
    slots: [UnsafeCell<MaybeUninit<T>>; N],
    // If there is no room in the queue, `Sender`s can wait here for a slot to become free.
    wait_queue: WaitQueue,
    // Keep track of the receiver.
    receiver_dropped: UnsafeCell<bool>,
    // Keep track of the number of senders.
    num_senders: UnsafeCell<usize>,
}

// NOTE(review): these impls carry no `T: Send` bound, so values of a non-`Send` `T` can
// cross threads through the channel. Presumably intended for single-core RTIC use —
// confirm before relying on it in multi-threaded code.
unsafe impl<T, const N: usize> Send for Channel<T, N> {}

unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
50
macro_rules! cs_access {
    ($name:ident, $type:ty) => {
        /// Run `f` with a mutable reference to the field of the same name.
        ///
        /// The `CriticalSection` token proves the caller is inside a critical section,
        /// which is what makes handing out `&mut` to the shared field sound.
        ///
        /// SAFETY: this function must not be called recursively within `f`.
        unsafe fn $name<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
        where
            F: FnOnce(&mut $type) -> R,
        {
            let cell_ptr = self.$name.get_mut();
            // SAFETY: we have exclusive access due to the critical section.
            f(unsafe { cell_ptr.deref() })
        }
    };
}
67
68impl<T, const N: usize> Default for Channel<T, N> {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl<T, const N: usize> Channel<T, N> {
75    const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");
76
77    /// Create a new channel.
78    #[cfg(not(loom))]
79    pub const fn new() -> Self {
80        Self {
81            freeq: UnsafeCell::new(Deque::new()),
82            readyq: UnsafeCell::new(Deque::new()),
83            receiver_waker: WakerRegistration::new(),
84            slots: [const { UnsafeCell::new(MaybeUninit::uninit()) }; N],
85            wait_queue: WaitQueue::new(),
86            receiver_dropped: UnsafeCell::new(false),
87            num_senders: UnsafeCell::new(0),
88        }
89    }
90
91    /// Create a new channel.
92    #[cfg(loom)]
93    pub fn new() -> Self {
94        Self {
95            freeq: UnsafeCell::new(Deque::new()),
96            readyq: UnsafeCell::new(Deque::new()),
97            receiver_waker: WakerRegistration::new(),
98            slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
99            wait_queue: WaitQueue::new(),
100            receiver_dropped: UnsafeCell::new(false),
101            num_senders: UnsafeCell::new(0),
102        }
103    }
104
105    /// Split the queue into a `Sender`/`Receiver` pair.
106    pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
107        // NOTE(assert): queue is cleared by dropping the corresponding `Receiver`.
108        debug_assert!(self.readyq.as_mut().is_empty(),);
109
110        let freeq = self.freeq.as_mut();
111
112        freeq.clear();
113
114        // Fill free queue
115        for idx in 0..N as u8 {
116            debug_assert!(!freeq.is_full());
117
118            // SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue,
119            // and the queue is cleared beforehand.
120            unsafe {
121                freeq.push_back_unchecked(idx);
122            }
123        }
124
125        debug_assert!(freeq.is_full());
126
127        // There is now 1 sender
128        *self.num_senders.as_mut() = 1;
129
130        (Sender(self), Receiver(self))
131    }
132
133    cs_access!(freeq, Deque<u8, N>);
134    cs_access!(readyq, Deque<u8, N>);
135    cs_access!(receiver_dropped, bool);
136    cs_access!(num_senders, usize);
137
138    /// Return free slot `slot` to the channel.
139    ///
140    /// This will do one of two things:
141    /// 1. If there are any waiting `send`-ers, wake the longest-waiting one and hand it `slot`.
142    /// 2. else, insert `slot` into `self.freeq`.
143    ///
144    /// SAFETY: `slot` must be a `u8` that is obtained by dequeueing from [`Self::readyq`], and that `slot`
145    /// is returned at most once.
146    unsafe fn return_free_slot(&self, slot: u8) {
147        critical_section::with(|cs| {
148            fence(Ordering::SeqCst);
149
150            // If someone is waiting in the `wait_queue`, wake the first one up & hand it the free slot.
151            if let Some((wait_head, mut freeq_slot)) = self.wait_queue.pop() {
152                // SAFETY: `freeq_slot` is valid for writes: we are in a critical
153                // section & the `SlotPtr` lives for at least the duration of the wait queue link.
154                unsafe { freeq_slot.replace(Some(slot), cs) };
155                wait_head.wake();
156            } else {
157                // SAFETY: `self.freeq` is not called recursively.
158                unsafe {
159                    self.freeq(cs, |freeq| {
160                        debug_assert!(!freeq.is_full());
161                        // SAFETY: `freeq` is not full.
162                        freeq.push_back_unchecked(slot);
163                    });
164                }
165            }
166        })
167    }
168
169    /// SAFETY: the caller must guarantee that `slot` is an `u8` obtained by dequeueing from [`Self::readyq`],
170    /// and is read at most once.
171    unsafe fn read_slot(&self, slot: u8) -> T {
172        let first_element = self.slots.get_unchecked(slot as usize).get_mut();
173        let ptr = first_element.deref().as_ptr();
174        ptr::read(ptr)
175    }
176}
177
/// Creates a split channel with `'static` lifetime.
///
/// Each expansion owns its own hidden `static` channel and may run only once: a second
/// call to the same expansion panics.
#[macro_export]
#[cfg(not(loom))]
macro_rules! make_channel {
    ($type:ty, $size:expr) => {{
        static mut CHANNEL: $crate::channel::Channel<$type, $size> =
            $crate::channel::Channel::new();

        // 0 until this expansion has split its channel, 1 afterwards.
        static CHECK: $crate::portable_atomic::AtomicU8 = $crate::portable_atomic::AtomicU8::new(0);

        $crate::channel::critical_section::with(|_| {
            let already_split = CHECK.load(::core::sync::atomic::Ordering::Relaxed) != 0;
            if already_split {
                panic!("call to the same `make_channel` instance twice");
            }

            CHECK.store(1, ::core::sync::atomic::Ordering::Relaxed);
        });

        // SAFETY: `CHANNEL` is only reachable from inside this expansion, and the `CHECK`
        // guard above ensures this mutable access happens at most once.
        unsafe { (*::core::ptr::addr_of_mut!(CHANNEL)).split() }
    }};
}
204
205// -------- Sender
206
/// Error state for when the receiver has been dropped.
///
/// Carries back the value that could not be sent.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub struct NoReceiver<T>(pub T);

/// Errors that `try_send` can have.
///
/// Both variants carry back the value that could not be sent.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub enum TrySendError<T> {
    /// Error state for when the receiver has been dropped.
    NoReceiver(T),
    /// Error state when the queue is full.
    Full(T),
}

impl<T> core::fmt::Debug for NoReceiver<T>
where
    T: core::fmt::Debug,
{
    /// Formats as `NoReceiver(<value>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let NoReceiver(inner) = self;
        write!(f, "NoReceiver({:?})", inner)
    }
}

impl<T> core::fmt::Debug for TrySendError<T>
where
    T: core::fmt::Debug,
{
    /// Formats as `NoReceiver(<value>)` or `Full(<value>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Full(v) => write!(f, "Full({v:?})"),
            Self::NoReceiver(v) => write!(f, "NoReceiver({v:?})"),
        }
    }
}

impl<T> PartialEq for TrySendError<T>
where
    T: PartialEq,
{
    /// Equal only when both the variant and the carried value are equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::NoReceiver(lhs), Self::NoReceiver(rhs)) | (Self::Full(lhs), Self::Full(rhs)) => {
                lhs == rhs
            }
            (Self::NoReceiver(_), Self::Full(_)) | (Self::Full(_), Self::NoReceiver(_)) => false,
        }
    }
}
254
/// A `Sender` can send to the channel and can be cloned.
///
/// Cloning increments the channel's sender count and dropping decrements it. The drop of
/// the last `Sender` wakes the receiver.
pub struct Sender<'a, T, const N: usize>(&'a Channel<T, N>);

// NOTE(review): no `T: Send` bound here either — see the `Send`/`Sync` impls on `Channel`.
unsafe impl<T, const N: usize> Send for Sender<'_, T, N> {}
259
260/// This is needed to make the async closure in `send` accept that we "share"
261/// the link possible between threads.
262#[derive(Clone)]
263struct LinkPtr(*mut Option<Link<WaitQueueData>>);
264
265impl LinkPtr {
266    /// This will dereference the pointer stored within and give out an `&mut`.
267    unsafe fn get(&mut self) -> &mut Option<Link<WaitQueueData>> {
268        &mut *self.0
269    }
270}
271
272unsafe impl Send for LinkPtr {}
273
274unsafe impl Sync for LinkPtr {}
275
276/// This is needed to make the async closure in `send` accept that we "share"
277/// the link possible between threads.
278#[derive(Clone)]
279struct SlotPtr(*mut Option<u8>);
280
281impl SlotPtr {
282    /// Replace the value of this slot with `new_value`, and return
283    /// the old value.
284    ///
285    /// SAFETY: the pointer in this `SlotPtr` must be valid for writes.
286    unsafe fn replace(
287        &mut self,
288        new_value: Option<u8>,
289        _cs: critical_section::CriticalSection,
290    ) -> Option<u8> {
291        // SAFETY: the critical section guarantees exclusive access, and the
292        // caller guarantees that the pointer is valid.
293        self.replace_exclusive(new_value)
294    }
295
296    /// Replace the value of this slot with `new_value`, and return
297    /// the old value.
298    ///
299    /// SAFETY: the pointer in this `SlotPtr` must be valid for writes, and the caller must guarantee exclusive
300    /// access to the underlying value..
301    unsafe fn replace_exclusive(&mut self, new_value: Option<u8>) -> Option<u8> {
302        // SAFETY: the caller has ensured that we have exclusive access & that
303        // the pointer is valid.
304        unsafe { core::ptr::replace(self.0, new_value) }
305    }
306}
307
308unsafe impl Send for SlotPtr {}
309
310unsafe impl Sync for SlotPtr {}
311
312impl<T, const N: usize> core::fmt::Debug for Sender<'_, T, N> {
313    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
314        write!(f, "Sender")
315    }
316}
317
318#[cfg(feature = "defmt-03")]
319impl<T, const N: usize> defmt::Format for Sender<'_, T, N> {
320    fn format(&self, f: defmt::Formatter) {
321        defmt::write!(f, "Sender",)
322    }
323}
324
325impl<T, const N: usize> Sender<'_, T, N> {
326    #[inline(always)]
327    fn send_footer(&mut self, idx: u8, val: T) {
328        // Write the value to the slots, note; this memcpy is not under a critical section.
329        unsafe {
330            let first_element = self.0.slots.get_unchecked(idx as usize).get_mut();
331            let ptr = first_element.deref().as_mut_ptr();
332            ptr::write(ptr, val)
333        }
334
335        // Write the value into the ready queue.
336        critical_section::with(|cs| {
337            // SAFETY: `self.0.readyq` is not called recursively.
338            unsafe {
339                self.0.readyq(cs, |readyq| {
340                    debug_assert!(!readyq.is_full());
341                    // SAFETY: ready is not full.
342                    readyq.push_back_unchecked(idx);
343                });
344            }
345        });
346
347        fence(Ordering::SeqCst);
348
349        // If there is a receiver waker, wake it.
350        self.0.receiver_waker.wake();
351    }
352
353    /// Try to send a value, non-blocking. If the channel is full this will return an error.
354    pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
355        // If the wait queue is not empty, we can't try to push into the queue.
356        if !self.0.wait_queue.is_empty() {
357            return Err(TrySendError::Full(val));
358        }
359
360        // No receiver available.
361        if self.is_closed() {
362            return Err(TrySendError::NoReceiver(val));
363        }
364
365        let free_slot = critical_section::with(|cs| unsafe {
366            // SAFETY: `self.0.freeq` is not called recursively.
367            self.0.freeq(cs, |q| q.pop_front())
368        });
369
370        let idx = if let Some(idx) = free_slot {
371            idx
372        } else {
373            return Err(TrySendError::Full(val));
374        };
375
376        self.send_footer(idx, val);
377
378        Ok(())
379    }
380
381    /// Send a value. If there is no place left in the queue this will wait until there is.
382    /// If the receiver does not exist this will return an error.
383    pub async fn send(&mut self, val: T) -> Result<(), NoReceiver<T>> {
384        let mut free_slot_ptr: Option<u8> = None;
385        let mut link_ptr: Option<Link<WaitQueueData>> = None;
386
387        // Make this future `Drop`-safe.
388        // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it.
389        let mut link_ptr = LinkPtr(core::ptr::addr_of_mut!(link_ptr));
390        // SAFETY(freed_slot): Shadow the original definition of `free_slot_ptr` so we can't abuse it.
391        let mut free_slot_ptr = SlotPtr(core::ptr::addr_of_mut!(free_slot_ptr));
392
393        let mut link_ptr2 = link_ptr.clone();
394        let mut free_slot_ptr2 = free_slot_ptr.clone();
395        let dropper = OnDrop::new(|| {
396            // SAFETY: We only run this closure and dereference the pointer if we have
397            // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference
398            // of this pointer is in the `poll_fn`.
399            if let Some(link) = unsafe { link_ptr2.get() } {
400                link.remove_from_list(&self.0.wait_queue);
401            }
402
403            // Return our potentially-unused free slot.
404            // Since we are certain that our link has been removed from the list (either
405            // pop-ed or removed just above), we have exclusive access to the free slot pointer.
406            if let Some(freed_slot) = unsafe { free_slot_ptr2.replace_exclusive(None) } {
407                // SAFETY: freed slot is passed to us from `return_free_slot`, which either
408                // directly (through `try_recv`), or indirectly (through another `return_free_slot`)
409                // comes from `readyq`.
410                unsafe { self.0.return_free_slot(freed_slot) };
411            }
412        });
413
414        let idx = poll_fn(|cx| {
415            //  Do all this in one critical section, else there can be race conditions
416            critical_section::with(|cs| {
417                if self.is_closed() {
418                    return Poll::Ready(Err(()));
419                }
420
421                let wq_empty = self.0.wait_queue.is_empty();
422                // SAFETY: `self.0.freeq` is not called recursively.
423                let freeq_empty = unsafe { self.0.freeq(cs, |q| q.is_empty()) };
424
425                // SAFETY: This pointer is only dereferenced here and on drop of the future
426                // which happens outside this `poll_fn`'s stack frame.
427                let link = unsafe { link_ptr.get() };
428
429                // We are already in the wait queue.
430                if let Some(queue_link) = link {
431                    if queue_link.is_popped() {
432                        // SAFETY: `free_slot_ptr` is valid for writes until the end of this future.
433                        let slot = unsafe { free_slot_ptr.replace(None, cs) };
434
435                        // Our link was popped, so it is most definitely not in the list.
436                        // We can safely & correctly `take` it to prevent ourselves from
437                        // redundantly attempting to remove it from the list a 2nd time.
438                        link.take();
439
440                        // If our link is popped, then:
441                        // 1. We were popped by `return_free_lot` and provided us with a slot.
442                        // 2. We were popped by `Receiver::drop` and it did not provide us with a slot, and the channel is closed.
443                        if let Some(slot) = slot {
444                            Poll::Ready(Ok(slot))
445                        } else {
446                            Poll::Ready(Err(()))
447                        }
448                    } else {
449                        Poll::Pending
450                    }
451                }
452                // We are not in the wait queue, but others are, or there is currently no free
453                // slot available.
454                else if !wq_empty || freeq_empty {
455                    // Place the link in the wait queue.
456                    let link_ref =
457                        link.insert(Link::new((cx.waker().clone(), free_slot_ptr.clone())));
458
459                    // SAFETY(new_unchecked): The address to the link is stable as it is defined
460                    // outside this stack frame.
461                    // SAFETY(push): `link_ref` lifetime comes from `link_ptr` and `free_slot_ptr` that
462                    // are shadowed and we make sure in `dropper` that the link is removed from the queue
463                    // before dropping `link_ptr` AND `dropper` makes sure that the shadowed
464                    // `ptr`s live until the end of the stack frame.
465                    unsafe { self.0.wait_queue.push(Pin::new_unchecked(link_ref)) };
466
467                    Poll::Pending
468                }
469                // We are not in the wait queue, no one else is waiting, and there is a free slot available.
470                else {
471                    // SAFETY: `self.0.freeq` is not called recursively.
472                    unsafe {
473                        self.0.freeq(cs, |freeq| {
474                            debug_assert!(!freeq.is_empty());
475                            // SAFETY: `freeq` is non-empty
476                            let slot = freeq.pop_back_unchecked();
477                            Poll::Ready(Ok(slot))
478                        })
479                    }
480                }
481            })
482        })
483        .await;
484
485        // Make sure the link is removed from the queue.
486        drop(dropper);
487
488        if let Ok(idx) = idx {
489            self.send_footer(idx, val);
490
491            Ok(())
492        } else {
493            Err(NoReceiver(val))
494        }
495    }
496
497    /// Returns true if there is no `Receiver`s.
498    pub fn is_closed(&self) -> bool {
499        critical_section::with(|cs| unsafe {
500            // SAFETY: `self.0.receiver_dropped` is not called recursively.
501            self.0.receiver_dropped(cs, |v| *v)
502        })
503    }
504
505    /// Is the queue full.
506    pub fn is_full(&self) -> bool {
507        critical_section::with(|cs| unsafe {
508            // SAFETY: `self.0.freeq` is not called recursively.
509            self.0.freeq(cs, |v| v.is_empty())
510        })
511    }
512
513    /// Is the queue empty.
514    pub fn is_empty(&self) -> bool {
515        critical_section::with(|cs| unsafe {
516            // SAFETY: `self.0.freeq` is not called recursively.
517            self.0.freeq(cs, |v| v.is_full())
518        })
519    }
520}
521
522impl<T, const N: usize> Drop for Sender<'_, T, N> {
523    fn drop(&mut self) {
524        // Count down the reference counter
525        let num_senders = critical_section::with(|cs| {
526            unsafe {
527                // SAFETY: `self.0.num_senders` is not called recursively.
528                self.0.num_senders(cs, |s| {
529                    *s -= 1;
530                    *s
531                })
532            }
533        });
534
535        // If there are no senders, wake the receiver to do error handling.
536        if num_senders == 0 {
537            self.0.receiver_waker.wake();
538        }
539    }
540}
541
542impl<T, const N: usize> Clone for Sender<'_, T, N> {
543    fn clone(&self) -> Self {
544        // Count up the reference counter
545        critical_section::with(|cs| unsafe {
546            // SAFETY: `self.0.num_senders` is not called recursively.
547            self.0.num_senders(cs, |v| *v += 1);
548        });
549
550        Self(self.0)
551    }
552}
553
554// -------- Receiver
555
/// A receiver of the channel. There can only be one receiver at any time.
///
/// Dropping it marks the channel closed for senders, drops any values still queued and
/// wakes every sender waiting for a free slot.
pub struct Receiver<'a, T, const N: usize>(&'a Channel<T, N>);

// NOTE(review): no `T: Send` bound here either — see the `Send`/`Sync` impls on `Channel`.
unsafe impl<T, const N: usize> Send for Receiver<'_, T, N> {}
560
561impl<T, const N: usize> core::fmt::Debug for Receiver<'_, T, N> {
562    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
563        write!(f, "Receiver")
564    }
565}
566
567#[cfg(feature = "defmt-03")]
568impl<T, const N: usize> defmt::Format for Receiver<'_, T, N> {
569    fn format(&self, f: defmt::Formatter) {
570        defmt::write!(f, "Receiver",)
571    }
572}
573
/// Possible receive errors.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ReceiveError {
    /// Error state for when all senders have been dropped.
    NoSender,
    /// Error state for when the queue is empty.
    Empty,
}
583
584impl<T, const N: usize> Receiver<'_, T, N> {
585    /// Receives a value if there is one in the channel, non-blocking.
586    pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
587        // Try to get a ready slot.
588        let ready_slot = critical_section::with(|cs| unsafe {
589            // SAFETY: `self.0.readyq` is not called recursively.
590            self.0.readyq(cs, |q| q.pop_front())
591        });
592
593        if let Some(rs) = ready_slot {
594            // Read the value from the slots, note; this memcpy is not under a critical section.
595            // SAFETY: `rs` is directly obtained from `self.0.readyq` and is read exactly
596            // once.
597            let r = unsafe { self.0.read_slot(rs) };
598
599            // Return the index to the free queue after we've read the value.
600            // SAFETY: `rs` comes directly from `readyq` and is only returned
601            // once.
602            unsafe { self.0.return_free_slot(rs) };
603
604            Ok(r)
605        } else if self.is_closed() {
606            Err(ReceiveError::NoSender)
607        } else {
608            Err(ReceiveError::Empty)
609        }
610    }
611
612    /// Receives a value, waiting if the queue is empty.
613    /// If all senders are dropped this will error with `NoSender`.
614    pub async fn recv(&mut self) -> Result<T, ReceiveError> {
615        // There was nothing in the queue, setup the waiting.
616        poll_fn(|cx| {
617            // Register waker.
618            // TODO: Should it happen here or after the if? This might cause a spurious wake.
619            self.0.receiver_waker.register(cx.waker());
620
621            // Try to dequeue.
622            match self.try_recv() {
623                Ok(val) => {
624                    return Poll::Ready(Ok(val));
625                }
626                Err(ReceiveError::NoSender) => {
627                    return Poll::Ready(Err(ReceiveError::NoSender));
628                }
629                _ => {}
630            }
631
632            Poll::Pending
633        })
634        .await
635    }
636
637    /// Returns true if there are no `Sender`s.
638    pub fn is_closed(&self) -> bool {
639        critical_section::with(|cs| unsafe {
640            // SAFETY: `self.0.num_senders` is not called recursively.
641            self.0.num_senders(cs, |v| *v == 0)
642        })
643    }
644
645    /// Is the queue full.
646    pub fn is_full(&self) -> bool {
647        critical_section::with(|cs| unsafe {
648            // SAFETY: `self.0.readyq` is not called recursively.
649            self.0.readyq(cs, |v| v.is_full())
650        })
651    }
652
653    /// Is the queue empty.
654    pub fn is_empty(&self) -> bool {
655        critical_section::with(|cs| unsafe {
656            // SAFETY: `self.0.readyq` is not called recursively.
657            self.0.readyq(cs, |v| v.is_empty())
658        })
659    }
660}
661
662impl<T, const N: usize> Drop for Receiver<'_, T, N> {
663    fn drop(&mut self) {
664        // Mark the receiver as dropped and wake all waiters
665        critical_section::with(|cs| unsafe {
666            // SAFETY: `self.0.receiver_dropped` is not called recursively.
667            self.0.receiver_dropped(cs, |v| *v = true);
668        });
669
670        let ready_slot = || {
671            critical_section::with(|cs| unsafe {
672                // SAFETY: `self.0.readyq` is not called recursively.
673                self.0.readyq(cs, |q| q.pop_back())
674            })
675        };
676
677        while let Some(slot) = ready_slot() {
678            // SAFETY: `slot` comes from `readyq` and is
679            // read exactly once.
680            drop(unsafe { self.0.read_slot(slot) })
681        }
682
683        while let Some((waker, _)) = self.0.wait_queue.pop() {
684            waker.wake();
685        }
686    }
687}
688
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use core::sync::atomic::AtomicBool;
    use std::sync::Arc;

    use cassette::Cassette;

    use super::*;

    // `is_empty` tracks a single send/receive round trip from both ends.
    #[test]
    fn empty() {
        let (mut s, mut r) = make_channel!(u32, 10);

        assert!(s.is_empty());
        assert!(r.is_empty());

        s.try_send(1).unwrap();

        assert!(!s.is_empty());
        assert!(!r.is_empty());

        r.try_recv().unwrap();

        assert!(s.is_empty());
        assert!(r.is_empty());
    }

    // `is_full` flips exactly when all 3 slots are occupied, from both ends.
    #[test]
    fn full() {
        let (mut s, mut r) = make_channel!(u32, 3);

        for _ in 0..3 {
            assert!(!s.is_full());
            assert!(!r.is_full());

            s.try_send(1).unwrap();
        }

        assert!(s.is_full());
        assert!(r.is_full());

        for _ in 0..3 {
            r.try_recv().unwrap();

            assert!(!s.is_full());
            assert!(!r.is_full());
        }
    }

    // Values come out in FIFO order; an 11th `try_send` on a 10-slot channel reports `Full`.
    #[test]
    fn send_recieve() {
        let (mut s, mut r) = make_channel!(u32, 10);

        for i in 0..10 {
            s.try_send(i).unwrap();
        }

        assert_eq!(s.try_send(11), Err(TrySendError::Full(11)));

        for i in 0..10 {
            assert_eq!(r.try_recv().unwrap(), i);
        }

        assert_eq!(r.try_recv(), Err(ReceiveError::Empty));
    }

    // Dropping the only sender closes the channel for the receiver.
    #[test]
    fn closed_recv() {
        let (s, mut r) = make_channel!(u32, 10);

        drop(s);

        assert!(r.is_closed());

        assert_eq!(r.try_recv(), Err(ReceiveError::NoSender));
    }

    // Dropping the receiver closes the channel for the sender.
    #[test]
    fn closed_sender() {
        let (mut s, r) = make_channel!(u32, 10);

        drop(r);

        assert!(s.is_closed());

        assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11)));
    }

    // One `make_channel!` expansion; calling this helper twice runs it twice.
    fn make() {
        let _ = make_channel!(u32, 10);
    }

    // A second run of the same `make_channel!` expansion must panic.
    #[test]
    #[should_panic]
    fn double_make_channel() {
        make();
        make();
    }

    // `make_channel!` accepts a tuple element type.
    #[test]
    fn tuple_channel() {
        let _ = make_channel!((i32, u32), 10);
    }

    // Test helper: run `f` on the channel's free queue inside a critical section.
    fn freeq<const N: usize, T, F, R>(channel: &Channel<T, N>, f: F) -> R
    where
        F: FnOnce(&mut Deque<u8, N>) -> R,
    {
        critical_section::with(|cs| unsafe { channel.freeq(cs, f) })
    }

    // A waiting `send` that is handed a slot (via `try_recv`) but then dropped before
    // completing must give that slot back to the free queue.
    #[test]
    fn dropping_waked_send_returns_freeq_item() {
        let (mut tx, mut rx) = make_channel!(u8, 1);

        tx.try_send(0).unwrap();
        assert!(freeq(&rx.0, |q| q.is_empty()));

        // Running this in a separate thread scope to ensure that `pinned_future` is dropped fully.
        //
        // Calling drop explicitly gets hairy because dropping things behind a `Pin` is not easy.
        std::thread::scope(|scope| {
            scope.spawn(|| {
                let pinned_future = core::pin::pin!(tx.send(1));
                let mut future = Cassette::new(pinned_future);

                future.poll_on();

                assert!(freeq(&rx.0, |q| q.is_empty()));
                assert!(!rx.0.wait_queue.is_empty());

                assert_eq!(rx.try_recv(), Ok(0));

                assert!(freeq(&rx.0, |q| q.is_empty()));
            });
        });

        assert!(!freeq(&rx.0, |q| q.is_empty()));

        // Make sure that rx & tx are alive until here for good measure.
        drop((tx, rx));
    }

    // Test helper: flips the shared flag when dropped, so tests can observe drops.
    #[derive(Debug)]
    struct SetToTrueOnDrop(Arc<AtomicBool>);

    impl Drop for SetToTrueOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    // A value still queued when the channel halves are dropped must itself be dropped.
    #[test]
    fn non_popped_item_is_dropped() {
        let mut channel: Channel<SetToTrueOnDrop, 1> = Channel::new();

        let (mut tx, rx) = channel.split();

        let value = Arc::new(AtomicBool::new(false));
        tx.try_send(SetToTrueOnDrop(value.clone())).unwrap();

        drop((tx, rx));

        assert!(value.load(Ordering::SeqCst));
    }

    // A channel can be split again after both halves of a previous split are dropped.
    #[test]
    pub fn splitting_empty_channel_works() {
        let mut channel: Channel<(), 1> = Channel::new();

        let (mut tx, rx) = channel.split();

        tx.try_send(()).unwrap();

        drop((tx, rx));

        channel.split();
    }
}
869
#[cfg(not(loom))]
#[cfg(test)]
mod tokio_tests {
    /// Many concurrently spawned senders contend for a small queue; every value must be
    /// received exactly once.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 1_000;
        const QUEUE_SIZE: usize = 10;

        let (s, mut r) = make_channel!(u32, QUEUE_SIZE);

        // Spawn every sender up front, each owning its own clone of `s`.
        let handles: std::vec::Vec<_> = (0..NUM_RUNS)
            .map(|i| {
                let mut sender = s.clone();
                tokio::spawn(async move {
                    sender.send(i as _).await.unwrap();
                })
            })
            .collect();

        let mut received = std::collections::BTreeSet::new();
        for _ in 0..NUM_RUNS {
            received.insert(r.recv().await.unwrap());
        }

        // All values are distinct, so a lost or duplicated message would shrink the set.
        assert_eq!(received.len(), NUM_RUNS);

        for handle in handles {
            handle.await.unwrap();
        }
    }
}
902
#[cfg(test)]
#[cfg(loom)]
mod loom_test {
    use cassette::Cassette;
    use loom::thread;

    // Loom counterpart of `make_channel!`: builds the channel at runtime (under loom,
    // `Channel::new` is not `const`) and leaks it so the halves borrow it for `'static`.
    #[macro_export]
    #[allow(missing_docs)]
    macro_rules! make_loom_channel {
        ($type:ty, $size:expr) => {{
            let channel: crate::channel::Channel<$type, $size> = super::Channel::new();
            let boxed = Box::new(channel);
            let boxed = Box::leak(boxed);

            // Each expansion leaks a fresh channel, so this `split` holds the only
            // mutable access to it; the leak gives the returned halves a `'static` borrow.
            boxed.split()
        }};
    }

    // This test tests the following scenarios:
    // 1. Receiver is dropped while concurrent senders are waiting to send.
    // 2. Concurrent senders are competing for the same free slot.
    #[test]
    pub fn concurrent_send_while_full_and_drop() {
        loom::model(|| {
            let (mut tx, mut rx) = make_loom_channel!([u8; 20], 1);
            let mut cloned = tx.clone();

            // Fill the single slot so both sends below have to wait.
            tx.try_send([1; 20]).unwrap();

            let handle1 = thread::spawn(move || {
                let future = std::pin::pin!(tx.send([1; 20]));
                let mut future = Cassette::new(future);
                // Poll at most twice; the result is not checked.
                if future.poll_on().is_none() {
                    future.poll_on();
                }
            });

            // May or may not free the slot, depending on the interleaving loom explores.
            rx.try_recv().ok();

            let future = std::pin::pin!(cloned.send([1; 20]));
            let mut future = Cassette::new(future);
            if future.poll_on().is_none() {
                future.poll_on();
            }

            drop(rx);

            handle1.join().unwrap();
        });
    }
}