1use crate::unsafecell::UnsafeCell;
4use core::{
5 future::poll_fn,
6 mem::MaybeUninit,
7 pin::Pin,
8 ptr,
9 sync::atomic::{fence, Ordering},
10 task::{Poll, Waker},
11};
12#[doc(hidden)]
13pub use critical_section;
14use heapless::Deque;
15use rtic_common::{
16 dropper::OnDrop, wait_queue::DoublyLinkedList, wait_queue::Link,
17 waker_registration::CriticalSectionWakerRegistration as WakerRegistration,
18};
19
20#[cfg(feature = "defmt-03")]
21use crate::defmt;
22
/// Payload of a wait-queue entry: the waiting sender's waker, plus the pointer
/// through which `return_free_slot` hands that sender a freed slot index.
type WaitQueueData = (Waker, SlotPtr);
/// Queue of senders waiting for a free slot.
type WaitQueue = DoublyLinkedList<WaitQueueData>;
25
/// Fixed-capacity channel with `N` slots of `T`, supporting many senders and a
/// single receiver.
///
/// Use [`Channel::split`] (or the `make_channel!` macro) to obtain the
/// [`Sender`]/[`Receiver`] pair.
pub struct Channel<T, const N: usize> {
    /// Indices of slots that currently hold no value.
    freeq: UnsafeCell<Deque<u8, N>>,
    /// Indices of slots holding a sent value, in the order they were sent
    /// (pushed at the back, received from the front).
    readyq: UnsafeCell<Deque<u8, N>>,
    /// The receiver's waker; woken on every send and when the last sender drops.
    receiver_waker: WakerRegistration,
    /// Value storage, addressed by the slot indices kept in `freeq`/`readyq`.
    slots: [UnsafeCell<MaybeUninit<T>>; N],
    /// Senders waiting for a free slot.
    wait_queue: WaitQueue,
    /// Set to `true` when the receiver is dropped.
    receiver_dropped: UnsafeCell<bool>,
    /// Number of live senders.
    num_senders: UnsafeCell<usize>,
}
46
// SAFETY: the queue and flag fields are only touched through the `cs_access!`
// accessors, which require a critical-section token, or through `&mut self` in
// `split`. A `slots` entry is only read or written by whoever currently holds
// its slot index.
// NOTE(review): neither impl has a `T: Send` bound, so a channel of a non-`Send`
// `T` is still `Send`/`Sync`. Confirm that this is intended.
unsafe impl<T, const N: usize> Send for Channel<T, N> {}

unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
50
/// Generates an accessor named after a `UnsafeCell` field. The accessor runs a
/// closure with `&mut` access to that field while a critical section is held.
macro_rules! cs_access {
    ($name:ident, $type:ty) => {
        /// Run `f` with mutable access to the field of the same name.
        ///
        /// # Safety
        ///
        /// Must not be called again for the same field from inside `f`: doing so
        /// would create two live `&mut` references to the same data. The
        /// `CriticalSection` token requires the caller to be inside a critical
        /// section.
        unsafe fn $name<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
        where
            F: FnOnce(&mut $type) -> R,
        {
            let v = self.$name.get_mut();
            // SAFETY: exclusivity comes from the critical section (token above)
            // and from the caller not re-entering this accessor within `f`.
            let v = unsafe { v.deref() };
            f(v)
        }
    };
}
67
68impl<T, const N: usize> Default for Channel<T, N> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl<T, const N: usize> Channel<T, N> {
75 const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");
76
77 #[cfg(not(loom))]
79 pub const fn new() -> Self {
80 Self {
81 freeq: UnsafeCell::new(Deque::new()),
82 readyq: UnsafeCell::new(Deque::new()),
83 receiver_waker: WakerRegistration::new(),
84 slots: [const { UnsafeCell::new(MaybeUninit::uninit()) }; N],
85 wait_queue: WaitQueue::new(),
86 receiver_dropped: UnsafeCell::new(false),
87 num_senders: UnsafeCell::new(0),
88 }
89 }
90
91 #[cfg(loom)]
93 pub fn new() -> Self {
94 Self {
95 freeq: UnsafeCell::new(Deque::new()),
96 readyq: UnsafeCell::new(Deque::new()),
97 receiver_waker: WakerRegistration::new(),
98 slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
99 wait_queue: WaitQueue::new(),
100 receiver_dropped: UnsafeCell::new(false),
101 num_senders: UnsafeCell::new(0),
102 }
103 }
104
105 pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
107 debug_assert!(self.readyq.as_mut().is_empty(),);
109
110 let freeq = self.freeq.as_mut();
111
112 freeq.clear();
113
114 for idx in 0..N as u8 {
116 debug_assert!(!freeq.is_full());
117
118 unsafe {
121 freeq.push_back_unchecked(idx);
122 }
123 }
124
125 debug_assert!(freeq.is_full());
126
127 *self.num_senders.as_mut() = 1;
129
130 (Sender(self), Receiver(self))
131 }
132
133 cs_access!(freeq, Deque<u8, N>);
134 cs_access!(readyq, Deque<u8, N>);
135 cs_access!(receiver_dropped, bool);
136 cs_access!(num_senders, usize);
137
138 unsafe fn return_free_slot(&self, slot: u8) {
147 critical_section::with(|cs| {
148 fence(Ordering::SeqCst);
149
150 if let Some((wait_head, mut freeq_slot)) = self.wait_queue.pop() {
152 unsafe { freeq_slot.replace(Some(slot), cs) };
155 wait_head.wake();
156 } else {
157 unsafe {
159 self.freeq(cs, |freeq| {
160 debug_assert!(!freeq.is_full());
161 freeq.push_back_unchecked(slot);
163 });
164 }
165 }
166 })
167 }
168
169 unsafe fn read_slot(&self, slot: u8) -> T {
172 let first_element = self.slots.get_unchecked(slot as usize).get_mut();
173 let ptr = first_element.deref().as_ptr();
174 ptr::read(ptr)
175 }
176}
177
/// Creates a channel backed by a private `static mut` and returns its
/// `(Sender, Receiver)` pair.
///
/// Takes the element type `$type` and the capacity `$size`.
///
/// # Panics
///
/// Panics if the same macro expansion is evaluated more than once. A second
/// `split` of the same `static` would hand out halves aliasing the first pair.
#[macro_export]
#[cfg(not(loom))]
macro_rules! make_channel {
    ($type:ty, $size:expr) => {{
        static mut CHANNEL: $crate::channel::Channel<$type, $size> =
            $crate::channel::Channel::new();

        // Guard flag: 0 until this expansion has been evaluated once.
        static CHECK: $crate::portable_atomic::AtomicU8 = $crate::portable_atomic::AtomicU8::new(0);

        // Test-and-set inside a critical section, so that two callers cannot
        // both observe 0.
        $crate::channel::critical_section::with(|_| {
            if CHECK.load(::core::sync::atomic::Ordering::Relaxed) != 0 {
                panic!("call to the same `make_channel` instance twice");
            }

            CHECK.store(1, ::core::sync::atomic::Ordering::Relaxed);
        });

        #[allow(static_mut_refs)]
        // SAFETY: `CHANNEL` can only be named inside this block, and the guard
        // above lets execution reach this point at most once per expansion. This
        // is therefore the only mutable access.
        unsafe {
            CHANNEL.split()
        }
    }};
}
204
/// Error returned by [`Sender::send`] when the receiver has been dropped. It
/// carries back the value that could not be sent.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub struct NoReceiver<T>(pub T);
210
/// Error returned by [`Sender::try_send`]. Each variant carries back the value
/// that could not be sent.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub enum TrySendError<T> {
    /// The receiver has been dropped.
    NoReceiver(T),
    /// No free slot is available, or other senders are already waiting for one.
    Full(T),
}
219
220impl<T> core::fmt::Debug for NoReceiver<T>
221where
222 T: core::fmt::Debug,
223{
224 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
225 write!(f, "NoReceiver({:?})", self.0)
226 }
227}
228
229impl<T> core::fmt::Debug for TrySendError<T>
230where
231 T: core::fmt::Debug,
232{
233 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
234 match self {
235 TrySendError::NoReceiver(v) => write!(f, "NoReceiver({v:?})"),
236 TrySendError::Full(v) => write!(f, "Full({v:?})"),
237 }
238 }
239}
240
241impl<T> PartialEq for TrySendError<T>
242where
243 T: PartialEq,
244{
245 fn eq(&self, other: &Self) -> bool {
246 match (self, other) {
247 (TrySendError::NoReceiver(v1), TrySendError::NoReceiver(v2)) => v1.eq(v2),
248 (TrySendError::NoReceiver(_), TrySendError::Full(_)) => false,
249 (TrySendError::Full(_), TrySendError::NoReceiver(_)) => false,
250 (TrySendError::Full(v1), TrySendError::Full(v2)) => v1.eq(v2),
251 }
252 }
253}
254
/// Sending half of a channel. Cloning it registers another sender.
pub struct Sender<'a, T, const N: usize>(&'a Channel<T, N>);

// NOTE(review): this impl has no `T: Send` bound, so a `Sender` can move
// non-`Send` values to another thread. Confirm that this is intended.
unsafe impl<T, const N: usize> Send for Sender<'_, T, N> {}
259
/// Raw pointer to the `Option<Link>` local of `Sender::send`. It is cloned so
/// that both the poll closure and the drop guard of that future can reach the
/// local.
#[derive(Clone)]
struct LinkPtr(*mut Option<Link<WaitQueueData>>);

impl LinkPtr {
    /// Dereference the pointer.
    ///
    /// # Safety
    ///
    /// The pointee must still be alive, and no other reference to it may be
    /// live while the returned borrow is in use.
    unsafe fn get(&mut self) -> &mut Option<Link<WaitQueueData>> {
        &mut *self.0
    }
}

// SAFETY: the pointer is only dereferenced by the `send` future that owns the
// pointee (see `Sender::send`).
unsafe impl Send for LinkPtr {}

unsafe impl Sync for LinkPtr {}
275
/// Raw pointer to the `Option<u8>` local in which a waiting `Sender::send`
/// receives a freed slot index. A copy travels inside the sender's wait-queue
/// entry, so that `Channel::return_free_slot` can hand a slot over directly.
#[derive(Clone)]
struct SlotPtr(*mut Option<u8>);

impl SlotPtr {
    /// Replace the pointed-to value and return the previous one, from inside a
    /// critical section.
    ///
    /// # Safety
    ///
    /// The pointee must be alive. The `_cs` token requires the caller to be
    /// inside a critical section, which serializes access with the other
    /// holders of this pointer.
    unsafe fn replace(
        &mut self,
        new_value: Option<u8>,
        _cs: critical_section::CriticalSection,
    ) -> Option<u8> {
        self.replace_exclusive(new_value)
    }

    /// Replace the pointed-to value and return the previous one, without a
    /// critical section.
    ///
    /// # Safety
    ///
    /// The pointee must be alive, and the caller must have exclusive access to
    /// it. For example, the owning wait-queue link has already been removed from
    /// the queue.
    unsafe fn replace_exclusive(&mut self, new_value: Option<u8>) -> Option<u8> {
        // SAFETY: liveness and exclusivity are upheld by the caller.
        unsafe { core::ptr::replace(self.0, new_value) }
    }
}

// SAFETY: the pointer is only dereferenced under the conditions documented on
// `replace`/`replace_exclusive`.
unsafe impl Send for SlotPtr {}

unsafe impl Sync for SlotPtr {}
311
312impl<T, const N: usize> core::fmt::Debug for Sender<'_, T, N> {
313 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
314 write!(f, "Sender")
315 }
316}
317
/// `defmt` formatting for [`Sender`]; prints only the type name.
#[cfg(feature = "defmt-03")]
impl<T, const N: usize> defmt::Format for Sender<'_, T, N> {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "Sender",)
    }
}
324
impl<T, const N: usize> Sender<'_, T, N> {
    /// Write `val` into slot `idx`, publish `idx` on the ready queue, and wake
    /// the receiver.
    ///
    /// `idx` must be a slot index this sender exclusively holds. It was either
    /// popped from the free queue or handed over by `return_free_slot`.
    #[inline(always)]
    fn send_footer(&mut self, idx: u8, val: T) {
        // Write the value into its slot. This copy is deliberately done outside
        // any critical section: no queue references the slot yet.
        unsafe {
            let first_element = self.0.slots.get_unchecked(idx as usize).get_mut();
            let ptr = first_element.deref().as_mut_ptr();
            ptr::write(ptr, val)
        }

        // Make the slot visible to the receiver.
        critical_section::with(|cs| {
            // SAFETY: `self.0.readyq` is not called recursively. Pushing without a
            // check relies on each slot index being in at most one place at a
            // time, so at most `N` indices can be in the ready queue.
            unsafe {
                self.0.readyq(cs, |readyq| {
                    debug_assert!(!readyq.is_full());
                    readyq.push_back_unchecked(idx);
                });
            }
        });

        fence(Ordering::SeqCst);

        // Notify the receiver through its registered waker.
        self.0.receiver_waker.wake();
    }

    /// Try to send `val` without waiting.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Full`] if other senders are already waiting for a slot,
    ///   or if no free slot is available.
    /// - [`TrySendError::NoReceiver`] if the receiver has been dropped.
    pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
        // Don't overtake senders that are already queued for a slot.
        if !self.0.wait_queue.is_empty() {
            return Err(TrySendError::Full(val));
        }

        if self.is_closed() {
            return Err(TrySendError::NoReceiver(val));
        }

        let free_slot = critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.freeq` is not called recursively.
            self.0.freeq(cs, |q| q.pop_front())
        });

        let idx = if let Some(idx) = free_slot {
            idx
        } else {
            return Err(TrySendError::Full(val));
        };

        self.send_footer(idx, val);

        Ok(())
    }

    /// Send `val`, waiting for a free slot if necessary.
    ///
    /// If no slot is free, or other senders are already waiting, this sender
    /// joins the wait queue. It is then either handed a slot directly by
    /// `return_free_slot`, or released without one when the receiver is dropped.
    ///
    /// # Errors
    ///
    /// Returns [`NoReceiver`], carrying `val` back, if the receiver has been
    /// dropped.
    pub async fn send(&mut self, val: T) -> Result<(), NoReceiver<T>> {
        let mut free_slot_ptr: Option<u8> = None;
        let mut link_ptr: Option<Link<WaitQueueData>> = None;

        // Shadow both locals with raw-pointer wrappers. From here on the locals
        // are only reachable through these pointers.
        let mut link_ptr = LinkPtr(core::ptr::addr_of_mut!(link_ptr));
        let mut free_slot_ptr = SlotPtr(core::ptr::addr_of_mut!(free_slot_ptr));

        let mut link_ptr2 = link_ptr.clone();
        let mut free_slot_ptr2 = free_slot_ptr.clone();
        // Cleanup guard: `OnDrop` presumably runs this closure when `dropper` is
        // dropped. That happens either explicitly after the wait below, or when
        // this future is dropped mid-wait.
        let dropper = OnDrop::new(|| {
            // SAFETY: only dereferenced here, outside the `poll_fn` below, and
            // the pointee is still alive in this future.
            if let Some(link) = unsafe { link_ptr2.get() } {
                // Still linked: unlink so the queue never holds a dangling entry.
                link.remove_from_list(&self.0.wait_queue);
            }

            // The link is no longer in the queue (removed above, or popped), so
            // nobody else can write the slot pointer. If a slot was handed over
            // but never used, give it back.
            if let Some(freed_slot) = unsafe { free_slot_ptr2.replace_exclusive(None) } {
                // SAFETY: the slot was handed to us by `return_free_slot` and was
                // never written.
                unsafe { self.0.return_free_slot(freed_slot) };
            }
        });

        let idx = poll_fn(|cx| {
            // Everything below runs in one critical section to avoid races
            // between the checks and the queue operations.
            critical_section::with(|cs| {
                if self.is_closed() {
                    return Poll::Ready(Err(()));
                }

                let wq_empty = self.0.wait_queue.is_empty();
                // SAFETY: `self.0.freeq` is not called recursively.
                let freeq_empty = unsafe { self.0.freeq(cs, |q| q.is_empty()) };

                // SAFETY: the pointee lives in this future. It is only otherwise
                // dereferenced by the drop guard, which runs after this `poll_fn`.
                let link = unsafe { link_ptr.get() };

                // Already in the wait queue.
                if let Some(queue_link) = link {
                    if queue_link.is_popped() {
                        // SAFETY: the slot pointer is alive; we hold a critical section.
                        let slot = unsafe { free_slot_ptr.replace(None, cs) };

                        // The link has already left the queue. Clearing it here
                        // leaves the drop guard nothing to unlink.
                        link.take();

                        // Popped with a slot: `return_free_slot` handed us one.
                        // Popped without a slot: `Receiver::drop` released us.
                        if let Some(slot) = slot {
                            Poll::Ready(Ok(slot))
                        } else {
                            Poll::Ready(Err(()))
                        }
                    } else {
                        Poll::Pending
                    }
                }
                // Not queued yet, but others are waiting or no slot is free:
                // join the wait queue.
                else if !wq_empty || freeq_empty {
                    let link_ref =
                        link.insert(Link::new((cx.waker().clone(), free_slot_ptr.clone())));

                    // SAFETY: the link's address is stable, because it lives in
                    // this future. The drop guard removes it from the queue
                    // before the future (and with it the link) goes away.
                    unsafe { self.0.wait_queue.push(Pin::new_unchecked(link_ref)) };

                    Poll::Pending
                }
                // Nobody waiting and a slot is free: take it.
                else {
                    // SAFETY: `self.0.freeq` is not called recursively.
                    unsafe {
                        self.0.freeq(cs, |freeq| {
                            debug_assert!(!freeq.is_empty());
                            // SAFETY: `freeq` was observed non-empty above, in
                            // this same critical section.
                            let slot = freeq.pop_back_unchecked();
                            Poll::Ready(Ok(slot))
                        })
                    }
                }
            })
        })
        .await;

        // Run the cleanup guard now: make sure the link is out of the queue
        // before continuing.
        drop(dropper);

        if let Ok(idx) = idx {
            self.send_footer(idx, val);

            Ok(())
        } else {
            Err(NoReceiver(val))
        }
    }

    /// Returns `true` once the receiver has been dropped.
    pub fn is_closed(&self) -> bool {
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.receiver_dropped` is not called recursively.
            self.0.receiver_dropped(cs, |v| *v)
        })
    }

    /// Returns `true` when the free queue is empty, i.e. no slot is currently
    /// available.
    pub fn is_full(&self) -> bool {
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.freeq` is not called recursively.
            self.0.freeq(cs, |v| v.is_empty())
        })
    }

    /// Returns `true` when every slot is back in the free queue.
    pub fn is_empty(&self) -> bool {
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.freeq` is not called recursively.
            self.0.freeq(cs, |v| v.is_full())
        })
    }
}
521
522impl<T, const N: usize> Drop for Sender<'_, T, N> {
523 fn drop(&mut self) {
524 let num_senders = critical_section::with(|cs| {
526 unsafe {
527 self.0.num_senders(cs, |s| {
529 *s -= 1;
530 *s
531 })
532 }
533 });
534
535 if num_senders == 0 {
537 self.0.receiver_waker.wake();
538 }
539 }
540}
541
542impl<T, const N: usize> Clone for Sender<'_, T, N> {
543 fn clone(&self) -> Self {
544 critical_section::with(|cs| unsafe {
546 self.0.num_senders(cs, |v| *v += 1);
548 });
549
550 Self(self.0)
551 }
552}
553
/// Receiving half of a channel. There is only one per split; it is not `Clone`.
pub struct Receiver<'a, T, const N: usize>(&'a Channel<T, N>);

// NOTE(review): this impl has no `T: Send` bound. Confirm that this is intended.
unsafe impl<T, const N: usize> Send for Receiver<'_, T, N> {}
560
561impl<T, const N: usize> core::fmt::Debug for Receiver<'_, T, N> {
562 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
563 write!(f, "Receiver")
564 }
565}
566
/// `defmt` formatting for [`Receiver`]; prints only the type name.
#[cfg(feature = "defmt-03")]
impl<T, const N: usize> defmt::Format for Receiver<'_, T, N> {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "Receiver",)
    }
}
573
/// Error returned by [`Receiver::try_recv`] and [`Receiver::recv`].
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ReceiveError {
    /// Nothing is ready and every sender has been dropped.
    NoSender,
    /// Nothing is ready yet. Only `try_recv` returns this; `recv` keeps waiting.
    Empty,
}
583
584impl<T, const N: usize> Receiver<'_, T, N> {
585 pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
587 let ready_slot = critical_section::with(|cs| unsafe {
589 self.0.readyq(cs, |q| q.pop_front())
591 });
592
593 if let Some(rs) = ready_slot {
594 let r = unsafe { self.0.read_slot(rs) };
598
599 unsafe { self.0.return_free_slot(rs) };
603
604 Ok(r)
605 } else if self.is_closed() {
606 Err(ReceiveError::NoSender)
607 } else {
608 Err(ReceiveError::Empty)
609 }
610 }
611
612 pub async fn recv(&mut self) -> Result<T, ReceiveError> {
615 poll_fn(|cx| {
617 self.0.receiver_waker.register(cx.waker());
620
621 match self.try_recv() {
623 Ok(val) => {
624 return Poll::Ready(Ok(val));
625 }
626 Err(ReceiveError::NoSender) => {
627 return Poll::Ready(Err(ReceiveError::NoSender));
628 }
629 _ => {}
630 }
631
632 Poll::Pending
633 })
634 .await
635 }
636
637 pub fn is_closed(&self) -> bool {
639 critical_section::with(|cs| unsafe {
640 self.0.num_senders(cs, |v| *v == 0)
642 })
643 }
644
645 pub fn is_full(&self) -> bool {
647 critical_section::with(|cs| unsafe {
648 self.0.readyq(cs, |v| v.is_full())
650 })
651 }
652
653 pub fn is_empty(&self) -> bool {
655 critical_section::with(|cs| unsafe {
656 self.0.readyq(cs, |v| v.is_empty())
658 })
659 }
660}
661
impl<T, const N: usize> Drop for Receiver<'_, T, N> {
    /// Close the channel, drop every value still in the ready queue, and wake
    /// all senders waiting for a slot.
    fn drop(&mut self) {
        // Mark the channel as closed first, so senders checked from now on see it.
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.receiver_dropped` is not called recursively.
            self.0.receiver_dropped(cs, |v| *v = true);
        });

        // Pops one ready slot index per call, each in its own critical section.
        let ready_slot = || {
            critical_section::with(|cs| unsafe {
                // SAFETY: `self.0.readyq` is not called recursively.
                self.0.readyq(cs, |q| q.pop_back())
            })
        };

        // Drop the values nobody will receive. The indices are not returned to
        // the free queue; `Channel::split` rebuilds that queue from scratch.
        while let Some(slot) = ready_slot() {
            // SAFETY: `slot` came from the ready queue, so it holds an initialized value.
            drop(unsafe { self.0.read_slot(slot) })
        }

        // Release waiting senders without handing them a slot. `Sender::send`
        // treats "popped without a slot" as "receiver gone".
        while let Some((waker, _)) = self.0.wait_queue.pop() {
            waker.wake();
        }
    }
}
688
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use core::sync::atomic::AtomicBool;
    use std::sync::Arc;

    use cassette::Cassette;

    use super::*;

    #[test]
    fn empty() {
        let (mut s, mut r) = make_channel!(u32, 10);

        assert!(s.is_empty());
        assert!(r.is_empty());

        s.try_send(1).unwrap();

        assert!(!s.is_empty());
        assert!(!r.is_empty());

        r.try_recv().unwrap();

        assert!(s.is_empty());
        assert!(r.is_empty());
    }

    #[test]
    fn full() {
        let (mut s, mut r) = make_channel!(u32, 3);

        for _ in 0..3 {
            assert!(!s.is_full());
            assert!(!r.is_full());

            s.try_send(1).unwrap();
        }

        assert!(s.is_full());
        assert!(r.is_full());

        for _ in 0..3 {
            r.try_recv().unwrap();

            assert!(!s.is_full());
            assert!(!r.is_full());
        }
    }

    #[test]
    fn send_receive() {
        let (mut s, mut r) = make_channel!(u32, 10);

        for i in 0..10 {
            s.try_send(i).unwrap();
        }

        assert_eq!(s.try_send(11), Err(TrySendError::Full(11)));

        for i in 0..10 {
            assert_eq!(r.try_recv().unwrap(), i);
        }

        assert_eq!(r.try_recv(), Err(ReceiveError::Empty));
    }

    #[test]
    fn closed_recv() {
        let (s, mut r) = make_channel!(u32, 10);

        drop(s);

        assert!(r.is_closed());

        assert_eq!(r.try_recv(), Err(ReceiveError::NoSender));
    }

    #[test]
    fn closed_sender() {
        let (mut s, r) = make_channel!(u32, 10);

        drop(r);

        assert!(s.is_closed());

        assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11)));
    }

    /// Evaluates one fixed `make_channel!` expansion; calling it twice must panic.
    fn make() {
        let _ = make_channel!(u32, 10);
    }

    #[test]
    #[should_panic]
    fn double_make_channel() {
        make();
        make();
    }

    #[test]
    fn tuple_channel() {
        let _ = make_channel!((i32, u32), 10);
    }

    /// Test helper: run `f` on the channel's free queue inside a critical section.
    fn freeq<const N: usize, T, F, R>(channel: &Channel<T, N>, f: F) -> R
    where
        F: FnOnce(&mut Deque<u8, N>) -> R,
    {
        critical_section::with(|cs| unsafe { channel.freeq(cs, f) })
    }

    /// A waiting `send` is handed a slot by `try_recv`, but the future is
    /// dropped before being polled again. Its drop guard must put that slot
    /// back into the free queue.
    #[test]
    fn dropping_waked_send_returns_freeq_item() {
        let (mut tx, mut rx) = make_channel!(u8, 1);

        tx.try_send(0).unwrap();
        assert!(freeq(&rx.0, |q| q.is_empty()));

        std::thread::scope(|scope| {
            scope.spawn(|| {
                let pinned_future = core::pin::pin!(tx.send(1));
                let mut future = Cassette::new(pinned_future);

                future.poll_on();

                assert!(freeq(&rx.0, |q| q.is_empty()));
                assert!(!rx.0.wait_queue.is_empty());

                assert_eq!(rx.try_recv(), Ok(0));

                // The slot went straight to the waiting sender, not to `freeq`.
                assert!(freeq(&rx.0, |q| q.is_empty()));
            });
        });

        assert!(!freeq(&rx.0, |q| q.is_empty()));

        drop((tx, rx));
    }

    #[derive(Debug)]
    struct SetToTrueOnDrop(Arc<AtomicBool>);

    impl Drop for SetToTrueOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn non_popped_item_is_dropped() {
        let mut channel: Channel<SetToTrueOnDrop, 1> = Channel::new();

        let (mut tx, rx) = channel.split();

        let value = Arc::new(AtomicBool::new(false));
        tx.try_send(SetToTrueOnDrop(value.clone())).unwrap();

        drop((tx, rx));

        assert!(value.load(Ordering::SeqCst));
    }

    #[test]
    pub fn splitting_empty_channel_works() {
        let mut channel: Channel<(), 1> = Channel::new();

        let (mut tx, rx) = channel.split();

        tx.try_send(()).unwrap();

        drop((tx, rx));

        channel.split();
    }
}
869
#[cfg(not(loom))]
#[cfg(test)]
mod tokio_tests {
    /// Many tasks each send one distinct value through a small channel. The
    /// receiver must see every value exactly once.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 1_000;
        const QUEUE_SIZE: usize = 10;

        let (s, mut r) = make_channel!(u32, QUEUE_SIZE);

        // One task per value, each with its own cloned sender.
        let handles: std::vec::Vec<_> = (0..NUM_RUNS)
            .map(|i| {
                let mut sender = s.clone();
                tokio::spawn(async move {
                    sender.send(i as _).await.unwrap();
                })
            })
            .collect();

        let mut seen = std::collections::BTreeSet::new();
        for _ in 0..NUM_RUNS {
            seen.insert(r.recv().await.unwrap());
        }

        assert_eq!(seen.len(), NUM_RUNS);

        for handle in handles {
            handle.await.unwrap();
        }
    }
}
902
#[cfg(test)]
#[cfg(loom)]
mod loom_test {
    use cassette::Cassette;
    use loom::thread;

    /// `make_channel!` is not available under `cfg(loom)`. This macro builds
    /// the channel on a leaked `Box` instead, so the halves borrow it for `'static`.
    #[macro_export]
    #[allow(missing_docs)]
    macro_rules! make_loom_channel {
        ($type:ty, $size:expr) => {{
            let channel: crate::channel::Channel<$type, $size> = super::Channel::new();
            let boxed = Box::new(channel);
            let boxed = Box::leak(boxed);

            boxed.split()
        }};
    }

    /// Model check: the channel starts full, one sender `send`s on a spawned
    /// thread, the receiver takes one value, a second sender `send`s
    /// concurrently, and then the receiver is dropped.
    #[test]
    pub fn concurrent_send_while_full_and_drop() {
        loom::model(|| {
            let (mut tx, mut rx) = make_loom_channel!([u8; 20], 1);
            let mut cloned = tx.clone();

            tx.try_send([1; 20]).unwrap();

            let handle1 = thread::spawn(move || {
                let future = std::pin::pin!(tx.send([1; 20]));
                let mut future = Cassette::new(future);
                // Poll a second time if the first poll did not complete.
                if future.poll_on().is_none() {
                    future.poll_on();
                }
            });

            rx.try_recv().ok();

            let future = std::pin::pin!(cloned.send([1; 20]));
            let mut future = Cassette::new(future);
            if future.poll_on().is_none() {
                future.poll_on();
            }

            drop(rx);

            handle1.join().unwrap();
        });
    }
}
955}