rtic_sync/
arbiter.rs

//! A mutex-like primitive with a FIFO wait queue and an unlimited number of waiters, for embedded systems.
2//!
3//! Example usage:
4//!
5//! ```rust
6//! # async fn select<F1, F2>(f1: F1, f2: F2) {}
7//! use rtic_sync::arbiter::Arbiter;
8//!
9//! // Instantiate an Arbiter with a static lifetime.
10//! static ARBITER: Arbiter<u32> = Arbiter::new(32);
11//!
12//! async fn run(){
13//!     let write_42 = async move {
14//!         *ARBITER.access().await = 42;
15//!     };
16//!
17//!     let write_1337 = async move {
18//!         *ARBITER.access().await = 1337;
19//!     };
20//!
21//!     // Attempt to access the Arbiter concurrently.
22//!     select(write_42, write_1337).await;
23//! }
24//! ```
25
26use core::cell::UnsafeCell;
27use core::future::poll_fn;
28use core::ops::{Deref, DerefMut};
29use core::pin::Pin;
30use core::task::{Poll, Waker};
31use portable_atomic::{fence, AtomicBool, Ordering};
32
33use rtic_common::dropper::OnDrop;
34use rtic_common::wait_queue::{Link, WaitQueue};
35
36/// This is needed to make the async closure in `send` accept that we "share"
37/// the link possible between threads.
38#[derive(Clone)]
39struct LinkPtr(*mut Option<Link<Waker>>);
40
41impl LinkPtr {
42    /// This will dereference the pointer stored within and give out an `&mut`.
43    unsafe fn get(&mut self) -> &mut Option<Link<Waker>> {
44        &mut *self.0
45    }
46}
47
48unsafe impl Send for LinkPtr {}
49unsafe impl Sync for LinkPtr {}
50
51/// An FIFO waitqueue for use in shared bus usecases.
52pub struct Arbiter<T> {
53    wait_queue: WaitQueue,
54    inner: UnsafeCell<T>,
55    taken: AtomicBool,
56}
57
58unsafe impl<T> Send for Arbiter<T> {}
59unsafe impl<T> Sync for Arbiter<T> {}
60
61impl<T> Arbiter<T> {
62    /// Create a new arbiter.
63    pub const fn new(inner: T) -> Self {
64        Self {
65            wait_queue: WaitQueue::new(),
66            inner: UnsafeCell::new(inner),
67            taken: AtomicBool::new(false),
68        }
69    }
70
71    /// Get access to the inner value in the [`Arbiter`]. This will wait until access is granted,
72    /// for non-blocking access use `try_access`.
73    pub async fn access(&self) -> ExclusiveAccess<'_, T> {
74        let mut link_ptr: Option<Link<Waker>> = None;
75
76        // Make this future `Drop`-safe.
77        // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it.
78        let mut link_ptr = LinkPtr(&mut link_ptr as *mut Option<Link<Waker>>);
79
80        let mut link_ptr2 = link_ptr.clone();
81        let dropper = OnDrop::new(|| {
82            // SAFETY: We only run this closure and dereference the pointer if we have
83            // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference
84            // of this pointer is in the `poll_fn`.
85            if let Some(link) = unsafe { link_ptr2.get() } {
86                link.remove_from_list(&self.wait_queue);
87            }
88        });
89
90        poll_fn(|cx| {
91            critical_section::with(|_| {
92                fence(Ordering::SeqCst);
93
94                // The queue is empty and noone has taken the value.
95                if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) {
96                    self.taken.store(true, Ordering::Relaxed);
97
98                    return Poll::Ready(());
99                }
100
101                // SAFETY: This pointer is only dereferenced here and on drop of the future
102                // which happens outside this `poll_fn`'s stack frame.
103                let link = unsafe { link_ptr.get() };
104                if let Some(link) = link {
105                    if link.is_popped() {
106                        return Poll::Ready(());
107                    }
108                } else {
109                    // Place the link in the wait queue on first run.
110                    let link_ref = link.insert(Link::new(cx.waker().clone()));
111
112                    // SAFETY(new_unchecked): The address to the link is stable as it is defined
113                    // outside this stack frame.
114                    // SAFETY(push): `link_ref` lifetime comes from `link_ptr` that is shadowed,
115                    // and  we make sure in `dropper` that the link is removed from the queue
116                    // before dropping `link_ptr` AND `dropper` makes sure that the shadowed
117                    // `link_ptr` lives until the end of the stack frame.
118                    unsafe { self.wait_queue.push(Pin::new_unchecked(link_ref)) };
119                }
120
121                Poll::Pending
122            })
123        })
124        .await;
125
126        // Make sure the link is removed from the queue.
127        drop(dropper);
128
129        // SAFETY: One only gets here if there is exlusive access.
130        ExclusiveAccess {
131            arbiter: self,
132            inner: unsafe { &mut *self.inner.get() },
133        }
134    }
135
136    /// Non-blockingly tries to access the underlying value.
137    /// If someone is in queue to get it, this will return `None`.
138    pub fn try_access(&self) -> Option<ExclusiveAccess<'_, T>> {
139        critical_section::with(|_| {
140            fence(Ordering::SeqCst);
141
142            // The queue is empty and noone has taken the value.
143            if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) {
144                self.taken.store(true, Ordering::Relaxed);
145
146                // SAFETY: One only gets here if there is exlusive access.
147                Some(ExclusiveAccess {
148                    arbiter: self,
149                    inner: unsafe { &mut *self.inner.get() },
150                })
151            } else {
152                None
153            }
154        })
155    }
156}
157
158/// This token represents exclusive access to the value protected by the [`Arbiter`].
159pub struct ExclusiveAccess<'a, T> {
160    arbiter: &'a Arbiter<T>,
161    inner: &'a mut T,
162}
163
164impl<T> Drop for ExclusiveAccess<'_, T> {
165    fn drop(&mut self) {
166        critical_section::with(|_| {
167            fence(Ordering::SeqCst);
168
169            if self.arbiter.wait_queue.is_empty() {
170                // If noone is in queue and we release exclusive access, reset `taken`.
171                self.arbiter.taken.store(false, Ordering::Relaxed);
172            } else if let Some(next) = self.arbiter.wait_queue.pop() {
173                // Wake the next one in queue.
174                next.wake();
175            }
176        })
177    }
178}
179
180impl<T> Deref for ExclusiveAccess<'_, T> {
181    type Target = T;
182
183    fn deref(&self) -> &Self::Target {
184        self.inner
185    }
186}
187
188impl<T> DerefMut for ExclusiveAccess<'_, T> {
189    fn deref_mut(&mut self) -> &mut Self::Target {
190        self.inner
191    }
192}
193
194/// SPI bus sharing using [`Arbiter`]
195pub mod spi {
196    use super::Arbiter;
197    use embedded_hal::digital::OutputPin;
198    use embedded_hal_async::{
199        delay::DelayNs,
200        spi::{ErrorType, Operation, SpiBus, SpiDevice},
201    };
202    use embedded_hal_bus::spi::DeviceError;
203
204    /// [`Arbiter`]-based shared bus implementation.
205    pub struct ArbiterDevice<'a, BUS, CS, D> {
206        bus: &'a Arbiter<BUS>,
207        cs: CS,
208        delay: D,
209    }
210
211    impl<'a, BUS, CS, D> ArbiterDevice<'a, BUS, CS, D> {
212        /// Create a new [`ArbiterDevice`].
213        pub fn new(bus: &'a Arbiter<BUS>, cs: CS, delay: D) -> Self {
214            Self { bus, cs, delay }
215        }
216    }
217
218    impl<BUS, CS, D> ErrorType for ArbiterDevice<'_, BUS, CS, D>
219    where
220        BUS: ErrorType,
221        CS: OutputPin,
222    {
223        type Error = DeviceError<BUS::Error, CS::Error>;
224    }
225
226    impl<Word, BUS, CS, D> SpiDevice<Word> for ArbiterDevice<'_, BUS, CS, D>
227    where
228        Word: Copy + 'static,
229        BUS: SpiBus<Word>,
230        CS: OutputPin,
231        D: DelayNs,
232    {
233        async fn transaction(
234            &mut self,
235            operations: &mut [Operation<'_, Word>],
236        ) -> Result<(), DeviceError<BUS::Error, CS::Error>> {
237            let mut bus = self.bus.access().await;
238
239            self.cs.set_low().map_err(DeviceError::Cs)?;
240
241            let op_res = 'ops: {
242                for op in operations {
243                    let res = match op {
244                        Operation::Read(buf) => bus.read(buf).await,
245                        Operation::Write(buf) => bus.write(buf).await,
246                        Operation::Transfer(read, write) => bus.transfer(read, write).await,
247                        Operation::TransferInPlace(buf) => bus.transfer_in_place(buf).await,
248                        Operation::DelayNs(ns) => match bus.flush().await {
249                            Err(e) => Err(e),
250                            Ok(()) => {
251                                self.delay.delay_ns(*ns).await;
252                                Ok(())
253                            }
254                        },
255                    };
256                    if let Err(e) = res {
257                        break 'ops Err(e);
258                    }
259                }
260                Ok(())
261            };
262
263            // On failure, it's important to still flush and deassert CS.
264            let flush_res = bus.flush().await;
265            let cs_res = self.cs.set_high();
266
267            op_res.map_err(DeviceError::Spi)?;
268            flush_res.map_err(DeviceError::Spi)?;
269            cs_res.map_err(DeviceError::Cs)?;
270
271            Ok(())
272        }
273    }
274}
275
276/// I2C bus sharing using [`Arbiter`]
277///
278/// An Example how to use it in RTIC application:
279/// ```text
280/// #[app(device = some_hal, peripherals = true, dispatchers = [TIM16])]
281/// mod app {
282///     use core::mem::MaybeUninit;
283///     use rtic_sync::{arbiter::{i2c::ArbiterDevice, Arbiter},
284///
285///     #[shared]
286///     struct Shared {}
287///
288///     #[local]
289///     struct Local {
290///         ens160: Ens160<ArbiterDevice<'static, I2c<'static, I2C1>>>,
291///     }
292///
293///     #[init(local = [
294///         i2c_arbiter: MaybeUninit<Arbiter<I2c<'static, I2C1>>> = MaybeUninit::uninit(),
295///     ])]
296///     fn init(cx: init::Context) -> (Shared, Local) {
297///         let i2c = I2c::new(cx.device.I2C1);
298///         let i2c_arbiter = cx.local.i2c_arbiter.write(Arbiter::new(i2c));
299///         let ens160 = Ens160::new(ArbiterDevice::new(i2c_arbiter), 0x52);
300///
301///         i2c_sensors::spawn(i2c_arbiter).ok();
302///
303///         (Shared {}, Local { ens160 })
304///     }
305///
306///     #[task(local = [ens160])]
307///     async fn i2c_sensors(cx: i2c_sensors::Context, i2c: &'static Arbiter<I2c<'static, I2C1>>) {
308///         use sensor::Asensor;
309///
310///         loop {
311///             // Use scope to make sure I2C access is dropped.
312///             {
313///                 // Read from sensor driver that wants to use I2C directly.
314///                 let mut i2c = i2c.access().await;
315///                 let status = Asensor::status(&mut i2c).await;
316///             }
317///
318///             // Read ENS160 sensor.
319///             let eco2 = cx.local.ens160.eco2().await;
320///         }
321///     }
322/// }
323/// ```
324pub mod i2c {
325    use super::Arbiter;
326    use embedded_hal::i2c::{AddressMode, ErrorType, Operation};
327    use embedded_hal_async::i2c::I2c;
328
329    /// [`Arbiter`]-based shared bus implementation for I2C.
330    pub struct ArbiterDevice<'a, BUS> {
331        bus: &'a Arbiter<BUS>,
332    }
333
334    impl<'a, BUS> ArbiterDevice<'a, BUS> {
335        /// Create a new [`ArbiterDevice`] for I2C.
336        pub fn new(bus: &'a Arbiter<BUS>) -> Self {
337            Self { bus }
338        }
339    }
340
341    impl<BUS> ErrorType for ArbiterDevice<'_, BUS>
342    where
343        BUS: ErrorType,
344    {
345        type Error = BUS::Error;
346    }
347
348    impl<BUS, A> I2c<A> for ArbiterDevice<'_, BUS>
349    where
350        BUS: I2c<A>,
351        A: AddressMode,
352    {
353        async fn read(&mut self, address: A, read: &mut [u8]) -> Result<(), Self::Error> {
354            let mut bus = self.bus.access().await;
355            bus.read(address, read).await
356        }
357
358        async fn write(&mut self, address: A, write: &[u8]) -> Result<(), Self::Error> {
359            let mut bus = self.bus.access().await;
360            bus.write(address, write).await
361        }
362
363        async fn write_read(
364            &mut self,
365            address: A,
366            write: &[u8],
367            read: &mut [u8],
368        ) -> Result<(), Self::Error> {
369            let mut bus = self.bus.access().await;
370            bus.write_read(address, write, read).await
371        }
372
373        async fn transaction(
374            &mut self,
375            address: A,
376            operations: &mut [Operation<'_>],
377        ) -> Result<(), Self::Error> {
378            let mut bus = self.bus.access().await;
379            bus.transaction(address, operations).await
380        }
381    }
382}
383
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns many tasks that each increment the shared counter once through the
    /// arbiter, then checks that no increment was lost.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 100_000;

        static ARB: Arbiter<usize> = Arbiter::new(0);

        // Spawn every task up front, keeping the join handles.
        let handles: std::vec::Vec<_> = (0..NUM_RUNS)
            .map(|_| {
                tokio::spawn(async move {
                    *ARB.access().await += 1;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(*ARB.access().await, NUM_RUNS)
    }
}