rtic_sync/arbiter.rs
//! A mutex-like arbiter that grants access in FIFO order to an unlimited number of waiters, for embedded systems.
2//!
3//! Example usage:
4//!
5//! ```rust
6//! # async fn select<F1, F2>(f1: F1, f2: F2) {}
7//! use rtic_sync::arbiter::Arbiter;
8//!
9//! // Instantiate an Arbiter with a static lifetime.
10//! static ARBITER: Arbiter<u32> = Arbiter::new(32);
11//!
12//! async fn run(){
13//! let write_42 = async move {
14//! *ARBITER.access().await = 42;
15//! };
16//!
17//! let write_1337 = async move {
18//! *ARBITER.access().await = 1337;
19//! };
20//!
21//! // Attempt to access the Arbiter concurrently.
22//! select(write_42, write_1337).await;
23//! }
24//! ```
25
26use core::cell::UnsafeCell;
27use core::future::poll_fn;
28use core::ops::{Deref, DerefMut};
29use core::pin::Pin;
30use core::task::{Poll, Waker};
31use portable_atomic::{fence, AtomicBool, Ordering};
32
33use rtic_common::dropper::OnDrop;
34use rtic_common::wait_queue::{Link, WaitQueue};
35
/// Raw pointer to the `Option<Link<Waker>>` slot owned by an in-flight `access` future.
///
/// This is needed to make the async closures in `access` accept that we "share"
/// the link between them, possibly between threads.
#[derive(Clone)]
struct LinkPtr(*mut Option<Link<Waker>>);
40
41impl LinkPtr {
42 /// This will dereference the pointer stored within and give out an `&mut`.
43 unsafe fn get(&mut self) -> &mut Option<Link<Waker>> {
44 &mut *self.0
45 }
46}
47
// Raw pointers are neither `Send` nor `Sync` by default; these impls opt `LinkPtr` in
// so that it can be captured by the closures inside `access`.
// NOTE(review): soundness rests on the dereference discipline documented at each
// `LinkPtr::get` call site, not on anything enforced by the type itself.
unsafe impl Send for LinkPtr {}
unsafe impl Sync for LinkPtr {}
50
/// A FIFO wait queue for use in shared-bus use cases.
///
/// Wraps a value of type `T` and hands out [`ExclusiveAccess`] tokens to it, one at a time.
pub struct Arbiter<T> {
    /// Links of the `access` futures that could not be served immediately.
    wait_queue: WaitQueue,
    /// The protected value.
    inner: UnsafeCell<T>,
    /// `true` from the moment access is granted until it is released with nobody left
    /// in the queue; it stays `true` while access is handed over to a queued waiter.
    taken: AtomicBool,
}
57
// `UnsafeCell<T>` makes the struct `!Sync` by default; these impls allow an `Arbiter`
// to live in a `static` and be shared between tasks.
// NOTE(review): there is no `T: Send` bound here although an `ExclusiveAccess` hands out
// `&mut T` to whichever task acquires it — confirm every `T` used is safe to send.
unsafe impl<T> Send for Arbiter<T> {}
unsafe impl<T> Sync for Arbiter<T> {}
60
61impl<T> Arbiter<T> {
62 /// Create a new arbiter.
63 pub const fn new(inner: T) -> Self {
64 Self {
65 wait_queue: WaitQueue::new(),
66 inner: UnsafeCell::new(inner),
67 taken: AtomicBool::new(false),
68 }
69 }
70
71 /// Get access to the inner value in the [`Arbiter`]. This will wait until access is granted,
72 /// for non-blocking access use `try_access`.
73 pub async fn access(&self) -> ExclusiveAccess<'_, T> {
74 let mut link_ptr: Option<Link<Waker>> = None;
75
76 // Make this future `Drop`-safe.
77 // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it.
78 let mut link_ptr = LinkPtr(&mut link_ptr as *mut Option<Link<Waker>>);
79
80 let mut link_ptr2 = link_ptr.clone();
81 let dropper = OnDrop::new(|| {
82 // SAFETY: We only run this closure and dereference the pointer if we have
83 // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference
84 // of this pointer is in the `poll_fn`.
85 if let Some(link) = unsafe { link_ptr2.get() } {
86 link.remove_from_list(&self.wait_queue);
87 }
88 });
89
90 poll_fn(|cx| {
91 critical_section::with(|_| {
92 fence(Ordering::SeqCst);
93
94 // The queue is empty and noone has taken the value.
95 if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) {
96 self.taken.store(true, Ordering::Relaxed);
97
98 return Poll::Ready(());
99 }
100
101 // SAFETY: This pointer is only dereferenced here and on drop of the future
102 // which happens outside this `poll_fn`'s stack frame.
103 let link = unsafe { link_ptr.get() };
104 if let Some(link) = link {
105 if link.is_popped() {
106 return Poll::Ready(());
107 }
108 } else {
109 // Place the link in the wait queue on first run.
110 let link_ref = link.insert(Link::new(cx.waker().clone()));
111
112 // SAFETY(new_unchecked): The address to the link is stable as it is defined
113 // outside this stack frame.
114 // SAFETY(push): `link_ref` lifetime comes from `link_ptr` that is shadowed,
115 // and we make sure in `dropper` that the link is removed from the queue
116 // before dropping `link_ptr` AND `dropper` makes sure that the shadowed
117 // `link_ptr` lives until the end of the stack frame.
118 unsafe { self.wait_queue.push(Pin::new_unchecked(link_ref)) };
119 }
120
121 Poll::Pending
122 })
123 })
124 .await;
125
126 // Make sure the link is removed from the queue.
127 drop(dropper);
128
129 // SAFETY: One only gets here if there is exlusive access.
130 ExclusiveAccess {
131 arbiter: self,
132 inner: unsafe { &mut *self.inner.get() },
133 }
134 }
135
136 /// Non-blockingly tries to access the underlying value.
137 /// If someone is in queue to get it, this will return `None`.
138 pub fn try_access(&self) -> Option<ExclusiveAccess<'_, T>> {
139 critical_section::with(|_| {
140 fence(Ordering::SeqCst);
141
142 // The queue is empty and noone has taken the value.
143 if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) {
144 self.taken.store(true, Ordering::Relaxed);
145
146 // SAFETY: One only gets here if there is exlusive access.
147 Some(ExclusiveAccess {
148 arbiter: self,
149 inner: unsafe { &mut *self.inner.get() },
150 })
151 } else {
152 None
153 }
154 })
155 }
156}
157
/// This token represents exclusive access to the value protected by the [`Arbiter`].
///
/// The value is reached through `Deref`/`DerefMut`; dropping the token releases the
/// arbiter, either waking the next queued waiter or marking the value as free.
pub struct ExclusiveAccess<'a, T> {
    /// The arbiter this token was issued by; used on drop to release access.
    arbiter: &'a Arbiter<T>,
    /// The protected value itself.
    inner: &'a mut T,
}
163
164impl<T> Drop for ExclusiveAccess<'_, T> {
165 fn drop(&mut self) {
166 critical_section::with(|_| {
167 fence(Ordering::SeqCst);
168
169 if self.arbiter.wait_queue.is_empty() {
170 // If noone is in queue and we release exclusive access, reset `taken`.
171 self.arbiter.taken.store(false, Ordering::Relaxed);
172 } else if let Some(next) = self.arbiter.wait_queue.pop() {
173 // Wake the next one in queue.
174 next.wake();
175 }
176 })
177 }
178}
179
180impl<T> Deref for ExclusiveAccess<'_, T> {
181 type Target = T;
182
183 fn deref(&self) -> &Self::Target {
184 self.inner
185 }
186}
187
188impl<T> DerefMut for ExclusiveAccess<'_, T> {
189 fn deref_mut(&mut self) -> &mut Self::Target {
190 self.inner
191 }
192}
193
194/// SPI bus sharing using [`Arbiter`]
195pub mod spi {
196 use super::Arbiter;
197 use embedded_hal::digital::OutputPin;
198 use embedded_hal_async::{
199 delay::DelayNs,
200 spi::{ErrorType, Operation, SpiBus, SpiDevice},
201 };
202 use embedded_hal_bus::spi::DeviceError;
203
204 /// [`Arbiter`]-based shared bus implementation.
205 pub struct ArbiterDevice<'a, BUS, CS, D> {
206 bus: &'a Arbiter<BUS>,
207 cs: CS,
208 delay: D,
209 }
210
211 impl<'a, BUS, CS, D> ArbiterDevice<'a, BUS, CS, D> {
212 /// Create a new [`ArbiterDevice`].
213 pub fn new(bus: &'a Arbiter<BUS>, cs: CS, delay: D) -> Self {
214 Self { bus, cs, delay }
215 }
216 }
217
218 impl<BUS, CS, D> ErrorType for ArbiterDevice<'_, BUS, CS, D>
219 where
220 BUS: ErrorType,
221 CS: OutputPin,
222 {
223 type Error = DeviceError<BUS::Error, CS::Error>;
224 }
225
226 impl<Word, BUS, CS, D> SpiDevice<Word> for ArbiterDevice<'_, BUS, CS, D>
227 where
228 Word: Copy + 'static,
229 BUS: SpiBus<Word>,
230 CS: OutputPin,
231 D: DelayNs,
232 {
233 async fn transaction(
234 &mut self,
235 operations: &mut [Operation<'_, Word>],
236 ) -> Result<(), DeviceError<BUS::Error, CS::Error>> {
237 let mut bus = self.bus.access().await;
238
239 self.cs.set_low().map_err(DeviceError::Cs)?;
240
241 let op_res = 'ops: {
242 for op in operations {
243 let res = match op {
244 Operation::Read(buf) => bus.read(buf).await,
245 Operation::Write(buf) => bus.write(buf).await,
246 Operation::Transfer(read, write) => bus.transfer(read, write).await,
247 Operation::TransferInPlace(buf) => bus.transfer_in_place(buf).await,
248 Operation::DelayNs(ns) => match bus.flush().await {
249 Err(e) => Err(e),
250 Ok(()) => {
251 self.delay.delay_ns(*ns).await;
252 Ok(())
253 }
254 },
255 };
256 if let Err(e) = res {
257 break 'ops Err(e);
258 }
259 }
260 Ok(())
261 };
262
263 // On failure, it's important to still flush and deassert CS.
264 let flush_res = bus.flush().await;
265 let cs_res = self.cs.set_high();
266
267 op_res.map_err(DeviceError::Spi)?;
268 flush_res.map_err(DeviceError::Spi)?;
269 cs_res.map_err(DeviceError::Cs)?;
270
271 Ok(())
272 }
273 }
274}
275
276/// I2C bus sharing using [`Arbiter`]
277///
278/// An Example how to use it in RTIC application:
279/// ```text
280/// #[app(device = some_hal, peripherals = true, dispatchers = [TIM16])]
281/// mod app {
282/// use core::mem::MaybeUninit;
283/// use rtic_sync::{arbiter::{i2c::ArbiterDevice, Arbiter},
284///
285/// #[shared]
286/// struct Shared {}
287///
288/// #[local]
289/// struct Local {
290/// ens160: Ens160<ArbiterDevice<'static, I2c<'static, I2C1>>>,
291/// }
292///
293/// #[init(local = [
294/// i2c_arbiter: MaybeUninit<Arbiter<I2c<'static, I2C1>>> = MaybeUninit::uninit(),
295/// ])]
296/// fn init(cx: init::Context) -> (Shared, Local) {
297/// let i2c = I2c::new(cx.device.I2C1);
298/// let i2c_arbiter = cx.local.i2c_arbiter.write(Arbiter::new(i2c));
299/// let ens160 = Ens160::new(ArbiterDevice::new(i2c_arbiter), 0x52);
300///
301/// i2c_sensors::spawn(i2c_arbiter).ok();
302///
303/// (Shared {}, Local { ens160 })
304/// }
305///
306/// #[task(local = [ens160])]
307/// async fn i2c_sensors(cx: i2c_sensors::Context, i2c: &'static Arbiter<I2c<'static, I2C1>>) {
308/// use sensor::Asensor;
309///
310/// loop {
311/// // Use scope to make sure I2C access is dropped.
312/// {
313/// // Read from sensor driver that wants to use I2C directly.
314/// let mut i2c = i2c.access().await;
315/// let status = Asensor::status(&mut i2c).await;
316/// }
317///
318/// // Read ENS160 sensor.
319/// let eco2 = cx.local.ens160.eco2().await;
320/// }
321/// }
322/// }
323/// ```
324pub mod i2c {
325 use super::Arbiter;
326 use embedded_hal::i2c::{AddressMode, ErrorType, Operation};
327 use embedded_hal_async::i2c::I2c;
328
329 /// [`Arbiter`]-based shared bus implementation for I2C.
330 pub struct ArbiterDevice<'a, BUS> {
331 bus: &'a Arbiter<BUS>,
332 }
333
334 impl<'a, BUS> ArbiterDevice<'a, BUS> {
335 /// Create a new [`ArbiterDevice`] for I2C.
336 pub fn new(bus: &'a Arbiter<BUS>) -> Self {
337 Self { bus }
338 }
339 }
340
341 impl<BUS> ErrorType for ArbiterDevice<'_, BUS>
342 where
343 BUS: ErrorType,
344 {
345 type Error = BUS::Error;
346 }
347
348 impl<BUS, A> I2c<A> for ArbiterDevice<'_, BUS>
349 where
350 BUS: I2c<A>,
351 A: AddressMode,
352 {
353 async fn read(&mut self, address: A, read: &mut [u8]) -> Result<(), Self::Error> {
354 let mut bus = self.bus.access().await;
355 bus.read(address, read).await
356 }
357
358 async fn write(&mut self, address: A, write: &[u8]) -> Result<(), Self::Error> {
359 let mut bus = self.bus.access().await;
360 bus.write(address, write).await
361 }
362
363 async fn write_read(
364 &mut self,
365 address: A,
366 write: &[u8],
367 read: &mut [u8],
368 ) -> Result<(), Self::Error> {
369 let mut bus = self.bus.access().await;
370 bus.write_read(address, write, read).await
371 }
372
373 async fn transaction(
374 &mut self,
375 address: A,
376 operations: &mut [Operation<'_>],
377 ) -> Result<(), Self::Error> {
378 let mut bus = self.bus.access().await;
379 bus.transaction(address, operations).await
380 }
381 }
382}
383
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns many tasks that each increment the shared counter once through the
    /// arbiter, then checks that no increment was lost.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 100_000;

        static ARB: Arbiter<usize> = Arbiter::new(0);

        // Spawn every task up front so they contend for the arbiter concurrently.
        let handles: std::vec::Vec<_> = (0..NUM_RUNS)
            .map(|_| {
                tokio::spawn(async move {
                    *ARB.access().await += 1;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(*ARB.access().await, NUM_RUNS);
    }
}
408}