rtic_sync/arbiter.rs
//! A mutex-like, FIFO-ordered arbiter with an unlimited number of waiters, for embedded systems.
2//!
3//! Example usage:
4//!
5//! ```rust
6//! # async fn select<F1, F2>(f1: F1, f2: F2) {}
7//! use rtic_sync::arbiter::Arbiter;
8//!
9//! // Instantiate an Arbiter with a static lifetime.
10//! static ARBITER: Arbiter<u32> = Arbiter::new(32);
11//!
12//! async fn run(){
13//! let write_42 = async move {
14//! *ARBITER.access().await = 42;
15//! };
16//!
17//! let write_1337 = async move {
18//! *ARBITER.access().await = 1337;
19//! };
20//!
21//! // Attempt to access the Arbiter concurrently.
22//! select(write_42, write_1337).await;
23//! }
24//! ```
25
26use core::cell::UnsafeCell;
27use core::future::poll_fn;
28use core::ops::{Deref, DerefMut};
29use core::pin::Pin;
30use core::task::{Poll, Waker};
31use portable_atomic::{fence, AtomicBool, Ordering};
32use rtic_common::dropper::OnDrop;
33use rtic_common::wait_queue::{Link, WaitQueue};
34
35pub mod i2c;
36pub mod spi;
37
38/// This is needed to make the async closure in `send` accept that we "share"
39/// the link possible between threads.
40#[derive(Clone)]
41struct LinkPtr(*mut Option<Link<Waker>>);
42
43impl LinkPtr {
44 /// This will dereference the pointer stored within and give out an `&mut`.
45 unsafe fn get(&mut self) -> &mut Option<Link<Waker>> {
46 &mut *self.0
47 }
48}
49
50unsafe impl Send for LinkPtr {}
51unsafe impl Sync for LinkPtr {}
52
53/// An FIFO waitqueue for use in shared bus usecases.
54pub struct Arbiter<T> {
55 wait_queue: WaitQueue,
56 inner: UnsafeCell<T>,
57 taken: AtomicBool,
58}
59
60unsafe impl<T> Send for Arbiter<T> {}
61unsafe impl<T> Sync for Arbiter<T> {}
62
63impl<T> Arbiter<T> {
64 /// Create a new arbiter.
65 pub const fn new(inner: T) -> Self {
66 Self {
67 wait_queue: WaitQueue::new(),
68 inner: UnsafeCell::new(inner),
69 taken: AtomicBool::new(false),
70 }
71 }
72
73 /// Get access to the inner value in the [`Arbiter`]. This will wait until access is granted,
74 /// for non-blocking access use `try_access`.
75 pub async fn access(&self) -> ExclusiveAccess<'_, T> {
76 let mut link_ptr: Option<Link<Waker>> = None;
77
78 // Make this future `Drop`-safe.
79 // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it.
80 let mut link_ptr = LinkPtr(&mut link_ptr as *mut Option<Link<Waker>>);
81
82 let mut link_ptr2 = link_ptr.clone();
83 let dropper = OnDrop::new(|| {
84 // SAFETY: We only run this closure and dereference the pointer if we have
85 // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference
86 // of this pointer is in the `poll_fn`.
87 if let Some(link) = unsafe { link_ptr2.get() } {
88 link.remove_from_list(&self.wait_queue);
89 }
90 });
91
92 poll_fn(|cx| {
93 critical_section::with(|_| {
94 fence(Ordering::SeqCst);
95
96 // The queue is empty and noone has taken the value.
97 if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) {
98 self.taken.store(true, Ordering::Relaxed);
99
100 return Poll::Ready(());
101 }
102
103 // SAFETY: This pointer is only dereferenced here and on drop of the future
104 // which happens outside this `poll_fn`'s stack frame.
105 let link = unsafe { link_ptr.get() };
106 if let Some(link) = link {
107 if link.is_popped() {
108 return Poll::Ready(());
109 }
110 } else {
111 // Place the link in the wait queue on first run.
112 let link_ref = link.insert(Link::new(cx.waker().clone()));
113
114 // SAFETY(new_unchecked): The address to the link is stable as it is defined
115 // outside this stack frame.
116 // SAFETY(push): `link_ref` lifetime comes from `link_ptr` that is shadowed,
117 // and we make sure in `dropper` that the link is removed from the queue
118 // before dropping `link_ptr` AND `dropper` makes sure that the shadowed
119 // `link_ptr` lives until the end of the stack frame.
120 unsafe { self.wait_queue.push(Pin::new_unchecked(link_ref)) };
121 }
122
123 Poll::Pending
124 })
125 })
126 .await;
127
128 // Make sure the link is removed from the queue.
129 drop(dropper);
130
131 // SAFETY: One only gets here if there is exlusive access.
132 ExclusiveAccess {
133 arbiter: self,
134 inner: unsafe { &mut *self.inner.get() },
135 }
136 }
137
138 /// Non-blockingly tries to access the underlying value.
139 /// If someone is in queue to get it, this will return `None`.
140 pub fn try_access(&self) -> Option<ExclusiveAccess<'_, T>> {
141 critical_section::with(|_| {
142 fence(Ordering::SeqCst);
143
144 // The queue is empty and noone has taken the value.
145 if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) {
146 self.taken.store(true, Ordering::Relaxed);
147
148 // SAFETY: One only gets here if there is exlusive access.
149 Some(ExclusiveAccess {
150 arbiter: self,
151 inner: unsafe { &mut *self.inner.get() },
152 })
153 } else {
154 None
155 }
156 })
157 }
158}
159
160/// This token represents exclusive access to the value protected by the [`Arbiter`].
161pub struct ExclusiveAccess<'a, T> {
162 arbiter: &'a Arbiter<T>,
163 inner: &'a mut T,
164}
165
166impl<T> Drop for ExclusiveAccess<'_, T> {
167 fn drop(&mut self) {
168 critical_section::with(|_| {
169 fence(Ordering::SeqCst);
170
171 if self.arbiter.wait_queue.is_empty() {
172 // If noone is in queue and we release exclusive access, reset `taken`.
173 self.arbiter.taken.store(false, Ordering::Relaxed);
174 } else if let Some(next) = self.arbiter.wait_queue.pop() {
175 // Wake the next one in queue.
176 next.wake();
177 }
178 })
179 }
180}
181
182impl<T> Deref for ExclusiveAccess<'_, T> {
183 type Target = T;
184
185 fn deref(&self) -> &Self::Target {
186 self.inner
187 }
188}
189
190impl<T> DerefMut for ExclusiveAccess<'_, T> {
191 fn deref_mut(&mut self) -> &mut Self::Target {
192 self.inner
193 }
194}
195
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;

    /// Many concurrent tasks incrementing a shared counter must never lose an update.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 100_000;

        static ARB: Arbiter<usize> = Arbiter::new(0);

        let handles: std::vec::Vec<_> = (0..NUM_RUNS)
            .map(|_| {
                tokio::spawn(async move {
                    *ARB.access().await += 1;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(*ARB.access().await, NUM_RUNS)
    }

    /// `try_access` must refuse while the value is held, and succeed again once the holder
    /// drops its access.
    #[test]
    fn try_access_is_exclusive() {
        let arb = Arbiter::new(7u32);

        let mut first = arb.try_access().expect("a fresh arbiter must be free");
        assert!(arb.try_access().is_none(), "value is already held");
        *first += 1;
        drop(first);

        let second = arb
            .try_access()
            .expect("dropping the access must release the arbiter");
        assert_eq!(*second, 8);
    }
}