1use crate::unsafecell::UnsafeCell;
4use core::{
5 future::poll_fn,
6 mem::MaybeUninit,
7 pin::Pin,
8 ptr,
9 sync::atomic::{fence, Ordering},
10 task::{Poll, Waker},
11};
12#[doc(hidden)]
13pub use critical_section;
14use heapless::Deque;
15use rtic_common::{
16 dropper::OnDrop, wait_queue::DoublyLinkedList, wait_queue::Link,
17 waker_registration::CriticalSectionWakerRegistration as WakerRegistration,
18};
19
20#[cfg(feature = "defmt-03")]
21use crate::defmt;
22
/// Payload of a wait-queue entry: the waiting sender's waker, and a pointer to
/// the `Option<u8>` inside that sender's future through which a freed slot
/// index is handed over.
type WaitQueueData = (Waker, SlotPtr);
/// List of senders waiting for a free slot; each link lives inside the waiting
/// `send` future and is pushed onto this list as a pinned reference.
type WaitQueue = DoublyLinkedList<WaitQueueData>;
25
/// A bounded channel with `N` slots of `T`.
///
/// Slots are tracked by `u8` index. A slot index is free (in `freeq`), ready
/// to be received (in `readyq`), or temporarily owned by a sender or the
/// receiver that is currently writing or reading that slot.
pub struct Channel<T, const N: usize> {
    /// Indices of slots that may be written by a sender.
    freeq: UnsafeCell<Deque<u8, N>>,
    /// Indices of slots holding a sent value, in the order they were sent.
    readyq: UnsafeCell<Deque<u8, N>>,
    /// Receiver's waker; woken after a send and when the last sender is dropped.
    receiver_waker: WakerRegistration,
    /// Value storage. A slot holds an initialized `T` from the moment a sender
    /// writes it until the receiver reads it out.
    slots: [UnsafeCell<MaybeUninit<T>>; N],
    /// Senders waiting in `send` for a free slot.
    wait_queue: WaitQueue,
    /// Set to `true` when the `Receiver` is dropped.
    receiver_dropped: UnsafeCell<bool>,
    /// Number of live `Sender` handles (set to 1 by `split`, +1 per clone, -1 per drop).
    num_senders: UnsafeCell<usize>,
}
46
// SAFETY: the queues and counters are only mutated inside critical sections
// (through the `cs_access!` accessors), and a slot's storage is only touched by
// whoever currently owns that slot's index.
// NOTE(review): unlike `std::sync::mpsc`, neither impl requires `T: Send`, so a
// `!Send` `T` can be moved between execution contexts through the channel —
// confirm this is intended.
unsafe impl<T, const N: usize> Send for Channel<T, N> {}

unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
50
/// Generates an accessor `$name(&self, cs, f)` that runs `f` with a mutable
/// reference to the `UnsafeCell` field `self.$name` of type `$type`.
///
/// The `CriticalSection` token is only taken as proof that the caller is inside
/// a critical section; it is not otherwise used.
macro_rules! cs_access {
    ($name:ident, $type:ty) => {
        /// Runs `f` with mutable access to this field.
        ///
        /// # Safety
        ///
        /// Must not be called for the same field from within `f`, which would
        /// create two live `&mut` references to the same data.
        unsafe fn $name<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
        where
            F: FnOnce(&mut $type) -> R,
        {
            let v = self.$name.get_mut();
            // SAFETY: the critical section excludes every other accessor of
            // this field, so this is the only live reference.
            let v = unsafe { v.deref() };
            f(v)
        }
    };
}
67
68impl<T, const N: usize> Default for Channel<T, N> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl<T, const N: usize> Channel<T, N> {
75 const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");
76
77 #[cfg(not(loom))]
79 pub const fn new() -> Self {
80 Self {
81 freeq: UnsafeCell::new(Deque::new()),
82 readyq: UnsafeCell::new(Deque::new()),
83 receiver_waker: WakerRegistration::new(),
84 slots: [const { UnsafeCell::new(MaybeUninit::uninit()) }; N],
85 wait_queue: WaitQueue::new(),
86 receiver_dropped: UnsafeCell::new(false),
87 num_senders: UnsafeCell::new(0),
88 }
89 }
90
91 #[cfg(loom)]
93 pub fn new() -> Self {
94 Self {
95 freeq: UnsafeCell::new(Deque::new()),
96 readyq: UnsafeCell::new(Deque::new()),
97 receiver_waker: WakerRegistration::new(),
98 slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
99 wait_queue: WaitQueue::new(),
100 receiver_dropped: UnsafeCell::new(false),
101 num_senders: UnsafeCell::new(0),
102 }
103 }
104
105 pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
107 let freeq = self.freeq.as_mut();
108
109 for idx in 0..N as u8 {
111 assert!(!freeq.is_full());
114
115 unsafe {
117 freeq.push_back_unchecked(idx);
118 }
119 }
120
121 debug_assert!(freeq.is_full());
122
123 *self.num_senders.as_mut() = 1;
125
126 (Sender(self), Receiver(self))
127 }
128
129 cs_access!(freeq, Deque<u8, N>);
130 cs_access!(readyq, Deque<u8, N>);
131 cs_access!(receiver_dropped, bool);
132 cs_access!(num_senders, usize);
133
134 unsafe fn return_free_slot(&self, slot: u8) {
142 critical_section::with(|cs| {
143 fence(Ordering::SeqCst);
144
145 if let Some((wait_head, mut freeq_slot)) = self.wait_queue.pop() {
147 unsafe { freeq_slot.replace(Some(slot), cs) };
150 wait_head.wake();
151 } else {
152 unsafe {
154 self.freeq(cs, |freeq| {
155 assert!(!freeq.is_full());
156 freeq.push_back_unchecked(slot);
158 });
159 }
160 }
161 })
162 }
163}
164
/// Creates a channel stored in a `static` and splits it, returning a
/// `(Sender, Receiver)` pair with `'static` lifetime.
///
/// `$type` is the element type and `$size` the capacity.
///
/// # Panics
///
/// Panics if the same macro invocation is executed more than once; each
/// expansion owns exactly one `static` channel.
#[macro_export]
#[cfg(not(loom))]
macro_rules! make_channel {
    ($type:ty, $size:expr) => {{
        static mut CHANNEL: $crate::channel::Channel<$type, $size> =
            $crate::channel::Channel::new();

        // Run-once flag for this expansion: splitting `CHANNEL` twice would
        // alias it.
        static CHECK: $crate::portable_atomic::AtomicU8 = $crate::portable_atomic::AtomicU8::new(0);

        // Load and store happen inside one critical section, so the
        // check-then-set cannot be interleaved.
        $crate::channel::critical_section::with(|_| {
            if CHECK.load(::core::sync::atomic::Ordering::Relaxed) != 0 {
                panic!("call to the same `make_channel` instance twice");
            }

            CHECK.store(1, ::core::sync::atomic::Ordering::Relaxed);
        });

        #[allow(static_mut_refs)]
        // SAFETY: the `CHECK` guard above lets this point be reached at most once
        // per expansion, so this is the only reference ever taken to `CHANNEL`.
        unsafe {
            CHANNEL.split()
        }
    }};
}
191
/// Error returned by [`Sender::send`] when the receiver has been dropped.
///
/// The field is the value that could not be sent, handed back to the caller.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub struct NoReceiver<T>(pub T);
197
/// Error returned by [`Sender::try_send`]; each variant hands back the unsent value.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub enum TrySendError<T> {
    /// The receiver has been dropped.
    NoReceiver(T),
    /// No free slot is available, or other senders are already waiting for one.
    Full(T),
}
206
207impl<T> core::fmt::Debug for NoReceiver<T>
208where
209 T: core::fmt::Debug,
210{
211 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
212 write!(f, "NoReceiver({:?})", self.0)
213 }
214}
215
216impl<T> core::fmt::Debug for TrySendError<T>
217where
218 T: core::fmt::Debug,
219{
220 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
221 match self {
222 TrySendError::NoReceiver(v) => write!(f, "NoReceiver({v:?})"),
223 TrySendError::Full(v) => write!(f, "Full({v:?})"),
224 }
225 }
226}
227
228impl<T> PartialEq for TrySendError<T>
229where
230 T: PartialEq,
231{
232 fn eq(&self, other: &Self) -> bool {
233 match (self, other) {
234 (TrySendError::NoReceiver(v1), TrySendError::NoReceiver(v2)) => v1.eq(v2),
235 (TrySendError::NoReceiver(_), TrySendError::Full(_)) => false,
236 (TrySendError::Full(_), TrySendError::NoReceiver(_)) => false,
237 (TrySendError::Full(v1), TrySendError::Full(v2)) => v1.eq(v2),
238 }
239 }
240}
241
/// Sending half of a [`Channel`]. Clone it to obtain additional senders.
pub struct Sender<'a, T, const N: usize>(&'a Channel<T, N>);

// SAFETY: all shared channel state is synchronized through critical sections.
// NOTE(review): there is no `T: Send` bound, so values of a `!Send` `T` can be
// sent from another execution context — confirm this is intended.
unsafe impl<T, const N: usize> Send for Sender<'_, T, N> {}
246
/// Raw pointer to the `Option<Link>` kept inside a pending `send` future.
///
/// It is cloned so that both the poll closure and the drop guard in `send`
/// can reach the same storage.
#[derive(Clone)]
struct LinkPtr(*mut Option<Link<WaitQueueData>>);

impl LinkPtr {
    /// Dereferences the pointer.
    ///
    /// # Safety
    ///
    /// The pointee must still be alive, and no other reference to it may be
    /// live while the returned `&mut` is in use.
    unsafe fn get(&mut self) -> &mut Option<Link<WaitQueueData>> {
        &mut *self.0
    }
}

// SAFETY(review): the pointer is only dereferenced by the `send` future that
// owns the pointee (poll closure and drop guard), which never run concurrently.
unsafe impl Send for LinkPtr {}

unsafe impl Sync for LinkPtr {}
262
/// Raw pointer to the `Option<u8>` inside a pending `send` future.
///
/// `return_free_slot` writes a freed slot index through this pointer to hand
/// that slot directly to the waiting sender.
#[derive(Clone)]
struct SlotPtr(*mut Option<u8>);

impl SlotPtr {
    /// Replaces the pointed-to value inside a critical section, returning the old value.
    ///
    /// # Safety
    ///
    /// The pointee must still be alive. The critical-section token serializes
    /// this write with other critical-section-guarded accesses to the pointee.
    unsafe fn replace(
        &mut self,
        new_value: Option<u8>,
        _cs: critical_section::CriticalSection,
    ) -> Option<u8> {
        self.replace_exclusive(new_value)
    }

    /// Replaces the pointed-to value without a critical section, returning the old value.
    ///
    /// # Safety
    ///
    /// The pointee must still be alive, and nobody else may access it
    /// concurrently. For example, the owning future's link must no longer be
    /// in the wait queue, where `return_free_slot` could write through a copy
    /// of this pointer.
    unsafe fn replace_exclusive(&mut self, new_value: Option<u8>) -> Option<u8> {
        // SAFETY: upheld by the caller (see `# Safety`).
        unsafe { core::ptr::replace(self.0, new_value) }
    }
}

// SAFETY(review): writes through the pointer happen either in a critical
// section (`replace`) or with exclusive access (`replace_exclusive`).
unsafe impl Send for SlotPtr {}

unsafe impl Sync for SlotPtr {}
298
299impl<T, const N: usize> core::fmt::Debug for Sender<'_, T, N> {
300 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
301 write!(f, "Sender")
302 }
303}
304
/// `defmt` formatting for [`Sender`]: writes the fixed text `Sender`.
#[cfg(feature = "defmt-03")]
impl<T, const N: usize> defmt::Format for Sender<'_, T, N> {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "Sender",)
    }
}
311
impl<T, const N: usize> Sender<'_, T, N> {
    /// Finishes a send once slot `idx` has been acquired.
    ///
    /// Writes `val` into the slot, publishes `idx` on the ready queue, and
    /// wakes the receiver.
    #[inline(always)]
    fn send_footer(&mut self, idx: u8, val: T) {
        // Write the value before its index becomes visible in `readyq`.
        // SAFETY: `idx` came from the free queue or was handed over by
        // `return_free_slot`, so this sender exclusively owns the slot.
        // NOTE(review): `get_unchecked` relies on every circulating index being
        // `< N`, which `split` establishes.
        unsafe {
            let first_element = self.0.slots.get_unchecked(idx as usize).get_mut();
            let ptr = first_element.deref().as_mut_ptr();
            ptr::write(ptr, val)
        }

        // Publish the slot to the receiver.
        critical_section::with(|cs| {
            // SAFETY: `self.0.readyq` is not called recursively.
            unsafe {
                self.0.readyq(cs, |readyq| {
                    assert!(!readyq.is_full());
                    // SAFETY: `readyq` is not full (asserted above).
                    readyq.push_back_unchecked(idx);
                });
            }
        });

        fence(Ordering::SeqCst);

        // Tell the receiver that a value is available.
        self.0.receiver_waker.wake();
    }

    /// Tries to send `val` without waiting.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Full`] if other senders are already waiting in
    ///   `send` (so this call does not overtake them), or if no slot is free.
    /// - [`TrySendError::NoReceiver`] if the receiver has been dropped.
    pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
        // Do not overtake senders that are already waiting for a slot.
        if !self.0.wait_queue.is_empty() {
            return Err(TrySendError::Full(val));
        }

        if self.is_closed() {
            return Err(TrySendError::NoReceiver(val));
        }

        let free_slot = critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.freeq` is not called recursively.
            self.0.freeq(cs, |q| q.pop_front())
        });

        let idx = if let Some(idx) = free_slot {
            idx
        } else {
            return Err(TrySendError::Full(val));
        };

        self.send_footer(idx, val);

        Ok(())
    }

    /// Sends `val`, waiting for a free slot if necessary.
    ///
    /// If other senders are already waiting, or no slot is free, the future
    /// links itself into the channel's wait queue. It completes once a slot is
    /// handed to it by `Channel::return_free_slot`.
    ///
    /// # Errors
    ///
    /// Returns [`NoReceiver`] carrying `val` if the receiver has been dropped,
    /// either before or while waiting.
    ///
    /// # Cancellation
    ///
    /// Dropping the future runs the `OnDrop` guard below. The guard unlinks the
    /// future from the wait queue and returns any slot already handed to it.
    pub async fn send(&mut self, val: T) -> Result<(), NoReceiver<T>> {
        // Storage that must persist across polls: the slot index handed over by
        // `return_free_slot`, and this future's wait-queue link. Both live in
        // the future's state and are reached through the raw pointers below.
        let mut free_slot_ptr: Option<u8> = None;
        let mut link_ptr: Option<Link<WaitQueueData>> = None;

        // NOTE: the original locals are shadowed but still alive; the wait
        // queue may hold a pinned reference to the link storage while linked.
        let mut link_ptr = LinkPtr(core::ptr::addr_of_mut!(link_ptr));
        // The wait-queue entry stores a copy of this pointer so that
        // `return_free_slot` can write the handed-over slot index.
        let mut free_slot_ptr = SlotPtr(core::ptr::addr_of_mut!(free_slot_ptr));

        let mut link_ptr2 = link_ptr.clone();
        let mut free_slot_ptr2 = free_slot_ptr.clone();
        // Cleanup that runs on completion and on cancellation.
        let dropper = OnDrop::new(|| {
            // SAFETY: `link_ptr2` points at the link storage above, which is
            // still alive, and the poll closure is not running.
            if let Some(link) = unsafe { link_ptr2.get() } {
                // Unlink this future if it is still in the wait queue.
                link.remove_from_list(&self.0.wait_queue);
            }

            // A slot may have been handed over and then left unused, e.g. the
            // future was dropped or saw the receiver close after being woken.
            // Give it back to the channel so it is not lost.
            // SAFETY: after the unlink above, presumably no other party can
            // write through `free_slot_ptr`, so access is exclusive.
            if let Some(freed_slot) = unsafe { free_slot_ptr2.replace_exclusive(None) } {
                // SAFETY: `freed_slot` was handed over by `return_free_slot` and
                // is owned by this future.
                unsafe { self.0.return_free_slot(freed_slot) };
            }
        });

        let idx = poll_fn(|cx| {
            // The closed check, queue inspection and enqueue/dequeue all run in
            // one critical section, so no other critical section (e.g.
            // `return_free_slot`) can interleave with them.
            critical_section::with(|cs| {
                // Receiver gone: fail. The drop guard returns any handed-over slot.
                if self.is_closed() {
                    return Poll::Ready(Err(()));
                }

                let wq_empty = self.0.wait_queue.is_empty();
                // SAFETY: `self.0.freeq` is not called recursively.
                let freeq_empty = unsafe { self.0.freeq(cs, |q| q.is_empty()) };

                // SAFETY: no other reference to the link storage is live while
                // the poll closure runs.
                let link = unsafe { link_ptr.get() };

                // Enqueued on an earlier poll?
                if let Some(queue_link) = link {
                    // Popped from the wait queue: either `return_free_slot` handed us
                    // a slot (`Some`) or the receiver's drop popped us (`None`).
                    if queue_link.is_popped() {
                        // SAFETY: inside a critical section.
                        let slot = unsafe { free_slot_ptr.replace(None, cs) };

                        // Clear the link so the drop guard does not try to unlink it.
                        link.take();

                        if let Some(slot) = slot {
                            Poll::Ready(Ok(slot))
                        } else {
                            Poll::Ready(Err(()))
                        }
                    } else {
                        // Still waiting. NOTE(review): the waker stored at enqueue time
                        // is not refreshed if `cx.waker()` changes between polls.
                        Poll::Pending
                    }
                }
                // Not enqueued yet, and either other senders are already waiting
                // (keep their turn) or there is no free slot: enqueue and wait.
                else if !wq_empty || freeq_empty {
                    // Store the link in this future's state; it stays there while enqueued.
                    let link_ref =
                        link.insert(Link::new((cx.waker().clone(), free_slot_ptr.clone())));

                    // SAFETY: the link lives in the pinned future state and is not moved
                    // while enqueued; the drop guard unlinks it before it goes away.
                    unsafe { self.0.wait_queue.push(Pin::new_unchecked(link_ref)) };

                    Poll::Pending
                }
                // Fast path: nobody is waiting and a slot is free.
                else {
                    // SAFETY: `self.0.freeq` is not called recursively.
                    unsafe {
                        self.0.freeq(cs, |freeq| {
                            assert!(!freeq.is_empty());
                            // SAFETY: `freeq` is non-empty (asserted above).
                            let slot = freeq.pop_back_unchecked();
                            Poll::Ready(Ok(slot))
                        })
                    }
                }
            })
        })
        .await;

        // Run the cleanup now. It must happen before `send_footer`, which needs
        // `&mut self` while the guard's closure borrows `self`.
        drop(dropper);

        if let Ok(idx) = idx {
            self.send_footer(idx, val);

            Ok(())
        } else {
            Err(NoReceiver(val))
        }
    }

    /// Returns `true` if the receiver has been dropped.
    pub fn is_closed(&self) -> bool {
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.receiver_dropped` is not called recursively.
            self.0.receiver_dropped(cs, |v| *v)
        })
    }

    /// Returns `true` if the free queue is empty, i.e. no slot is immediately available.
    pub fn is_full(&self) -> bool {
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.freeq` is not called recursively.
            self.0.freeq(cs, |v| v.is_empty())
        })
    }

    /// Returns `true` if the free queue is full, i.e. every slot is free.
    pub fn is_empty(&self) -> bool {
        critical_section::with(|cs| unsafe {
            // SAFETY: `self.0.freeq` is not called recursively.
            self.0.freeq(cs, |v| v.is_full())
        })
    }
}
508
509impl<T, const N: usize> Drop for Sender<'_, T, N> {
510 fn drop(&mut self) {
511 let num_senders = critical_section::with(|cs| {
513 unsafe {
514 self.0.num_senders(cs, |s| {
516 *s -= 1;
517 *s
518 })
519 }
520 });
521
522 if num_senders == 0 {
524 self.0.receiver_waker.wake();
525 }
526 }
527}
528
529impl<T, const N: usize> Clone for Sender<'_, T, N> {
530 fn clone(&self) -> Self {
531 critical_section::with(|cs| unsafe {
533 self.0.num_senders(cs, |v| *v += 1);
535 });
536
537 Self(self.0)
538 }
539}
540
/// Receiving half of a [`Channel`]. `split` creates exactly one receiver, and
/// it cannot be cloned.
pub struct Receiver<'a, T, const N: usize>(&'a Channel<T, N>);

// SAFETY: all shared channel state is synchronized through critical sections.
// NOTE(review): there is no `T: Send` bound — confirm this is intended.
unsafe impl<T, const N: usize> Send for Receiver<'_, T, N> {}
547
548impl<T, const N: usize> core::fmt::Debug for Receiver<'_, T, N> {
549 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
550 write!(f, "Receiver")
551 }
552}
553
/// `defmt` formatting for [`Receiver`]: writes the fixed text `Receiver`.
#[cfg(feature = "defmt-03")]
impl<T, const N: usize> defmt::Format for Receiver<'_, T, N> {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "Receiver",)
    }
}
560
/// Error returned by [`Receiver::try_recv`] and [`Receiver::recv`].
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ReceiveError {
    /// No value is queued and every sender has been dropped.
    NoSender,
    /// No value is queued right now (only returned by `try_recv`; `recv` waits instead).
    Empty,
}
570
571impl<T, const N: usize> Receiver<'_, T, N> {
572 pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
574 let ready_slot = critical_section::with(|cs| unsafe {
576 self.0.readyq(cs, |q| q.pop_front())
578 });
579
580 if let Some(rs) = ready_slot {
581 let r = unsafe {
583 let first_element = self.0.slots.get_unchecked(rs as usize).get_mut();
584 let ptr = first_element.deref().as_ptr();
585 ptr::read(ptr)
586 };
587
588 unsafe { self.0.return_free_slot(rs) };
591
592 Ok(r)
593 } else if self.is_closed() {
594 Err(ReceiveError::NoSender)
595 } else {
596 Err(ReceiveError::Empty)
597 }
598 }
599
600 pub async fn recv(&mut self) -> Result<T, ReceiveError> {
603 poll_fn(|cx| {
605 self.0.receiver_waker.register(cx.waker());
608
609 match self.try_recv() {
611 Ok(val) => {
612 return Poll::Ready(Ok(val));
613 }
614 Err(ReceiveError::NoSender) => {
615 return Poll::Ready(Err(ReceiveError::NoSender));
616 }
617 _ => {}
618 }
619
620 Poll::Pending
621 })
622 .await
623 }
624
625 pub fn is_closed(&self) -> bool {
627 critical_section::with(|cs| unsafe {
628 self.0.num_senders(cs, |v| *v == 0)
630 })
631 }
632
633 pub fn is_full(&self) -> bool {
635 critical_section::with(|cs| unsafe {
636 self.0.readyq(cs, |v| v.is_full())
638 })
639 }
640
641 pub fn is_empty(&self) -> bool {
643 critical_section::with(|cs| unsafe {
644 self.0.readyq(cs, |v| v.is_empty())
646 })
647 }
648}
649
650impl<T, const N: usize> Drop for Receiver<'_, T, N> {
651 fn drop(&mut self) {
652 critical_section::with(|cs| unsafe {
654 self.0.receiver_dropped(cs, |v| *v = true);
656 });
657
658 while let Some((waker, _)) = self.0.wait_queue.pop() {
659 waker.wake();
660 }
661 }
662}
663
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use cassette::Cassette;

    use super::*;

    /// `is_empty` on both halves tracks whether a value is queued.
    #[test]
    fn empty() {
        let (mut s, mut r) = make_channel!(u32, 10);

        assert!(s.is_empty());
        assert!(r.is_empty());

        s.try_send(1).unwrap();

        assert!(!s.is_empty());
        assert!(!r.is_empty());

        r.try_recv().unwrap();

        assert!(s.is_empty());
        assert!(r.is_empty());
    }

    /// `is_full` on both halves becomes true exactly when every slot is used.
    #[test]
    fn full() {
        let (mut s, mut r) = make_channel!(u32, 3);

        for _ in 0..3 {
            assert!(!s.is_full());
            assert!(!r.is_full());

            s.try_send(1).unwrap();
        }

        assert!(s.is_full());
        assert!(r.is_full());

        for _ in 0..3 {
            r.try_recv().unwrap();

            assert!(!s.is_full());
            assert!(!r.is_full());
        }
    }

    /// Values are received in send order; sending into a full channel fails.
    #[test]
    fn send_receive() {
        let (mut s, mut r) = make_channel!(u32, 10);

        for i in 0..10 {
            s.try_send(i).unwrap();
        }

        assert_eq!(s.try_send(11), Err(TrySendError::Full(11)));

        for i in 0..10 {
            assert_eq!(r.try_recv().unwrap(), i);
        }

        assert_eq!(r.try_recv(), Err(ReceiveError::Empty));
    }

    /// Dropping the only sender closes the channel for the receiver.
    #[test]
    fn closed_recv() {
        let (s, mut r) = make_channel!(u32, 10);

        drop(s);

        assert!(r.is_closed());

        assert_eq!(r.try_recv(), Err(ReceiveError::NoSender));
    }

    /// Dropping the receiver closes the channel for senders.
    #[test]
    fn closed_sender() {
        let (mut s, r) = make_channel!(u32, 10);

        drop(r);

        assert!(s.is_closed());

        assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11)));
    }

    /// Expands a single `make_channel!` instance; calling this twice runs it twice.
    fn make() {
        let _ = make_channel!(u32, 10);
    }

    #[test]
    #[should_panic]
    fn double_make_channel() {
        make();
        make();
    }

    /// Tuple element types are accepted by the macro.
    #[test]
    fn tuple_channel() {
        let _ = make_channel!((i32, u32), 10);
    }

    /// Test helper: runs `f` on the channel's free queue inside a critical section.
    fn freeq<const N: usize, T, F, R>(channel: &Channel<T, N>, f: F) -> R
    where
        F: FnOnce(&mut Deque<u8, N>) -> R,
    {
        critical_section::with(|cs| unsafe { channel.freeq(cs, f) })
    }

    /// A `send` future queued on a full channel, then handed a slot by
    /// `try_recv`, must return that slot to the free queue when it is dropped
    /// without being polled again.
    #[test]
    fn dropping_waked_send_returns_freeq_item() {
        let (mut tx, mut rx) = make_channel!(u8, 1);

        tx.try_send(0).unwrap();
        assert!(freeq(&rx.0, |q| q.is_empty()));

        // The future is created, polled and dropped entirely inside this scoped thread.
        std::thread::scope(|scope| {
            scope.spawn(|| {
                let pinned_future = core::pin::pin!(tx.send(1));
                let mut future = Cassette::new(pinned_future);

                // Channel is full, so this poll enqueues the sender.
                future.poll_on();

                assert!(freeq(&rx.0, |q| q.is_empty()));
                assert!(!rx.0.wait_queue.is_empty());

                // Receiving hands the freed slot directly to the waiting sender...
                assert_eq!(rx.try_recv(), Ok(0));

                // ...so it does not show up in the free queue.
                assert!(freeq(&rx.0, |q| q.is_empty()));
            });
        });

        // The dropped future gave its unused slot back.
        assert!(!freeq(&rx.0, |q| q.is_empty()));

        drop((tx, rx));
    }
}
805
#[cfg(not(loom))]
#[cfg(test)]
mod tokio_tests {
    /// Many concurrent senders on a small channel: every value must arrive once.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 1_000;
        const QUEUE_SIZE: usize = 10;

        let (s, mut r) = make_channel!(u32, QUEUE_SIZE);

        // One task per value; most of them have to wait for a free slot.
        let handles = (0..NUM_RUNS)
            .map(|i| {
                let mut sender = s.clone();
                tokio::spawn(async move {
                    sender.send(i as _).await.unwrap();
                })
            })
            .collect::<std::vec::Vec<_>>();

        let mut received = std::collections::BTreeSet::new();
        for _ in 0..NUM_RUNS {
            received.insert(r.recv().await.unwrap());
        }

        // All sent values are distinct, so a duplicate would shrink the set.
        assert_eq!(received.len(), NUM_RUNS);

        for handle in handles {
            handle.await.unwrap();
        }
    }
}
838
#[cfg(test)]
#[cfg(loom)]
mod loom_test {
    use cassette::Cassette;
    use loom::thread;

    /// Loom counterpart of `make_channel!`.
    ///
    /// Leaks a boxed channel to obtain `'static` halves, since the loom
    /// `Channel::new` is not `const` and cannot initialize a `static`.
    #[macro_export]
    #[allow(missing_docs)]
    macro_rules! make_loom_channel {
        ($type:ty, $size:expr) => {{
            let channel: crate::channel::Channel<$type, $size> = super::Channel::new();
            let boxed = Box::new(channel);
            let boxed = Box::leak(boxed);

            boxed.split()
        }};
    }

    /// Model-checks a full single-slot channel: one sender waits in `send` on
    /// another thread while the main thread receives, sends through a clone,
    /// and then drops the receiver.
    #[test]
    pub fn concurrent_send_while_full_and_drop() {
        loom::model(|| {
            let (mut tx, mut rx) = make_loom_channel!([u8; 20], 1);
            let mut cloned = tx.clone();

            // Fill the only slot.
            tx.try_send([1; 20]).unwrap();

            let handle1 = thread::spawn(move || {
                let future = std::pin::pin!(tx.send([1; 20]));
                let mut future = Cassette::new(future);
                // Poll at most twice: once more if the first poll is pending.
                if future.poll_on().is_none() {
                    future.poll_on();
                }
            });

            rx.try_recv().ok();

            let future = std::pin::pin!(cloned.send([1; 20]));
            let mut future = Cassette::new(future);
            // Poll at most twice: once more if the first poll is pending.
            if future.poll_on().is_none() {
                future.poll_on();
            }

            // Closing the channel must release any sender still waiting.
            drop(rx);

            handle1.join().unwrap();
        });
    }
}