// rtic_common/wait_queue.rs

//! A wait queue implementation using a doubly linked list.

3use core::{
4    future::poll_fn,
5    marker::PhantomPinned,
6    pin::{pin, Pin},
7    ptr::null_mut,
8    task::{Poll, Waker},
9};
10use critical_section as cs;
11use portable_atomic::{AtomicBool, AtomicPtr, Ordering};
12
13use crate::dropper::OnDropWith;
14
15/// A helper definition of a wait queue.
16pub type WaitQueue = DoublyLinkedList<Waker>;
17
18/// An atomic, doubly linked, FIFO list for a wait queue.
19///
20/// Atomicity is guaranteed by short [`critical_section`]s, so this list is _not_ lock free,
21/// but it will not deadlock.
22pub struct DoublyLinkedList<T> {
23    head: AtomicPtr<Link<T>>, // UnsafeCell<*mut Link<T>>
24    tail: AtomicPtr<Link<T>>,
25}
26
27impl<T> DoublyLinkedList<T> {
28    /// Create a new linked list.
29    pub const fn new() -> Self {
30        Self {
31            head: AtomicPtr::new(null_mut()),
32            tail: AtomicPtr::new(null_mut()),
33        }
34    }
35}
36
37impl<T> Default for DoublyLinkedList<T> {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl<T: Clone> DoublyLinkedList<T> {
44    const R: Ordering = Ordering::Relaxed;
45
46    /// Pop the first element in the queue.
47    pub fn pop(&self) -> Option<T> {
48        cs::with(|_| {
49            // Make sure all previous writes are visible
50            core::sync::atomic::fence(Ordering::SeqCst);
51
52            let head = self.head.load(Self::R);
53
54            // SAFETY: `as_ref` is safe as `insert` requires a valid reference to a link
55            if let Some(head_ref) = unsafe { head.as_ref() } {
56                // Move head to the next element
57                self.head.store(head_ref.next.load(Self::R), Self::R);
58
59                // We read the value at head
60                let head_val = head_ref.val.clone();
61
62                let tail = self.tail.load(Self::R);
63                if head == tail {
64                    // The queue is empty
65                    self.tail.store(null_mut(), Self::R);
66                }
67
68                if let Some(next_ref) = unsafe { head_ref.next.load(Self::R).as_ref() } {
69                    next_ref.prev.store(null_mut(), Self::R);
70                }
71
72                // Clear the pointers in the node.
73                head_ref.next.store(null_mut(), Self::R);
74                head_ref.prev.store(null_mut(), Self::R);
75                head_ref.is_popped.store(true, Self::R);
76
77                return Some(head_val);
78            }
79
80            None
81        })
82    }
83
84    /// Put an element at the back of the queue.
85    ///
86    /// # Safety
87    ///
88    /// The link must live until it is removed from the queue.
89    pub unsafe fn push(&self, link: Pin<&Link<T>>) {
90        cs::with(|_| {
91            // Make sure all previous writes are visible
92            core::sync::atomic::fence(Ordering::SeqCst);
93
94            let tail = self.tail.load(Self::R);
95
96            // SAFETY: This datastructure does not move the underlying value.
97            let link = link.get_ref();
98
99            if let Some(tail_ref) = unsafe { tail.as_ref() } {
100                // Queue is not empty
101                link.prev.store(tail, Self::R);
102                self.tail.store(link as *const _ as *mut _, Self::R);
103                tail_ref.next.store(link as *const _ as *mut _, Self::R);
104            } else {
105                // Queue is empty
106                self.tail.store(link as *const _ as *mut _, Self::R);
107                self.head.store(link as *const _ as *mut _, Self::R);
108            }
109        });
110    }
111
112    /// Check if the queue is empty.
113    pub fn is_empty(&self) -> bool {
114        self.head.load(Self::R).is_null()
115    }
116}
117
/// A link in the linked list.
///
/// Links are owned by their users, not by the list: `DoublyLinkedList::push` only receives a
/// `Pin<&Link<T>>` and stores a raw pointer to it, so the link must stay pinned and alive
/// while it is in a list. `PhantomPinned` makes the type `!Unpin`.
pub struct Link<T> {
    /// The payload; `DoublyLinkedList::pop` returns a clone of it.
    pub(crate) val: T,
    /// Pointer to the next link (towards the tail); null for a fresh link and for the tail.
    next: AtomicPtr<Link<T>>,
    /// Pointer to the previous link (towards the head); null for a fresh link and for the head.
    prev: AtomicPtr<Link<T>>,
    /// Set to `true` when the link is popped from, or removed from, a list.
    is_popped: AtomicBool,
    /// Opts out of `Unpin`, because the list holds raw pointers to this link.
    _up: PhantomPinned,
}
126
127impl<T: Clone> Link<T> {
128    const R: Ordering = Ordering::Relaxed;
129
130    /// Create a new link.
131    pub const fn new(val: T) -> Self {
132        Self {
133            val,
134            next: AtomicPtr::new(null_mut()),
135            prev: AtomicPtr::new(null_mut()),
136            is_popped: AtomicBool::new(false),
137            _up: PhantomPinned,
138        }
139    }
140
141    /// Return true if this link has been poped from the list.
142    pub fn is_popped(&self) -> bool {
143        self.is_popped.load(Self::R)
144    }
145
146    /// Remove this link from a linked list.
147    pub fn remove_from_list(&self, list: &DoublyLinkedList<T>) {
148        cs::with(|_| {
149            // Make sure all previous writes are visible
150            core::sync::atomic::fence(Ordering::SeqCst);
151
152            if self.is_popped() {
153                return;
154            }
155
156            let prev = self.prev.load(Self::R);
157            let next = self.next.load(Self::R);
158            self.is_popped.store(true, Self::R);
159
160            match unsafe { (prev.as_ref(), next.as_ref()) } {
161                (None, None) => {
162                    // Not in the list or alone in the list, check if list head == node address
163                    let sp = self as *const _;
164
165                    if sp == list.head.load(Ordering::Relaxed) {
166                        list.head.store(null_mut(), Self::R);
167                        list.tail.store(null_mut(), Self::R);
168                    }
169                }
170                (None, Some(next_ref)) => {
171                    // First in the list
172                    next_ref.prev.store(null_mut(), Self::R);
173                    list.head.store(next, Self::R);
174                }
175                (Some(prev_ref), None) => {
176                    // Last in the list
177                    prev_ref.next.store(null_mut(), Self::R);
178                    list.tail.store(prev, Self::R);
179                }
180                (Some(prev_ref), Some(next_ref)) => {
181                    // Somewhere in the list
182
183                    // Connect the `prev.next` and `next.prev` with each other to remove the node
184                    prev_ref.next.store(next, Self::R);
185                    next_ref.prev.store(prev, Self::R);
186                }
187            }
188        })
189    }
190}
191
#[cfg(test)]
impl<T: core::fmt::Debug + Clone> DoublyLinkedList<T> {
    /// Test helper: dump the head/tail pointers and then every link, from head to tail.
    fn print(&self) {
        cs::with(|_| {
            // Fence so that every earlier write is visible before the list is walked.
            core::sync::atomic::fence(Ordering::SeqCst);

            let first = self.head.load(Self::R);
            let last = self.tail.load(Self::R);
            println!("List - h = 0x{:x}, t = 0x{:x}", first as usize, last as usize);

            // SAFETY: every non-null pointer reachable from `head` was added by `push`, whose
            // contract keeps the link alive until it is removed from the list.
            let links = core::iter::successors(unsafe { first.as_ref() }, |link| unsafe {
                link.next.load(Self::R).as_ref()
            });

            for (i, link) in links.enumerate() {
                println!(
                    "    {}: {:?}, s = 0x{:x}, n = 0x{:x}, p = 0x{:x}",
                    i,
                    link.val,
                    link as *const Link<T> as usize,
                    link.next.load(Ordering::Relaxed) as usize,
                    link.prev.load(Ordering::Relaxed) as usize
                );
            }
        });
    }
}
227
228impl DoublyLinkedList<Waker> {
229    /// Wait until `f` returns `Some`.
230    pub async fn wait_until<T, F: FnMut() -> Option<T>>(&self, mut f: F) -> T {
231        let link_place = pin!(None::<Link<Waker>>);
232
233        let mut link_guard = OnDropWith::new(link_place, |link| {
234            if let Some(link) = link.as_ref().as_pin_ref() {
235                link.remove_from_list(self);
236            }
237            link.set(None);
238        });
239
240        poll_fn(move |cx| {
241            // clean up the old link, because we are going to invalidate it.
242            // we are doing it before returning `Poll::Ready` to handle cases
243            // where the future is polled after it is completed.
244            link_guard.execute();
245
246            if let Some(val) = f() {
247                return Poll::Ready(val);
248            }
249
250            // note: we may introduce a more complex logic to try to reuse the old link
251            // with the old waker by using `Waker::will_wake` to avoid `Waker::clone`,
252            // but it is probably not needed as Rtic's `waker` is cheap to clone.
253
254            // By the contract, each poll we should update the waker.
255            let new_link = Link::new(cx.waker().clone());
256
257            // Store the link into the pinned place.
258            link_guard.set(Some(new_link));
259
260            let new_link_pinned = link_guard.as_ref().as_pin_ref().expect("We just set it");
261
262            // SAFETY: we guarantee that `link` will live until removed by cleaning it up
263            // in the destructor of the future and that destructor is guaranteed to run
264            // before it's memory is reused or invalidated because the future is pinned.
265            unsafe { self.push(new_link_pinned) };
266
267            Poll::Pending
268        })
269        .await
270    }
271}
272
#[cfg(test)]
impl<T: core::fmt::Debug + Clone> Link<T> {
    /// Test helper: dump this link's value and the raw addresses of its neighbours.
    fn print(&self) {
        cs::with(|_| {
            // Fence so that every earlier write is visible before the pointers are read.
            core::sync::atomic::fence(Ordering::SeqCst);

            let next_addr = self.next.load(Ordering::Relaxed) as usize;
            let prev_addr = self.prev.load(Ordering::Relaxed) as usize;

            println!("Link:");
            println!(
                "    val = {:?}, n = 0x{:x}, p = 0x{:x}",
                self.val, next_addr, prev_addr
            );
        });
    }
}
291
/// Test that the future returned by `wait_until` is not `Unpin`.
///
/// `Pin::new` is only available for `Unpin` targets, so polling the future through it must
/// fail to compile:
/// ```compile_fail
/// fn test_unpin(list: &rtic_common::wait_queue::DoublyLinkedList<core::task::Waker>, cx: &mut core::task::Context) {
///     let mut wait_until_future = list.wait_until(|| None::<()>);
///     let pinned = core::pin::Pin::new(&mut wait_until_future);
///     core::future::Future::poll(pinned, cx);
///  }
/// ```
/// Control test: the same code compiles once the future is pinned with `pin!`. This shows
/// that the test above fails only because the future is `!Unpin`, not for some other reason.
/// ```
/// fn test_unpin(list: &rtic_common::wait_queue::DoublyLinkedList<core::task::Waker>, cx: &mut core::task::Context) {
///     let mut wait_until_future = list.wait_until(|| None::<()>);
///     let pinned = core::pin::pin!(wait_until_future);
///     core::future::Future::poll(pinned, cx);
///  }
/// ```
// Empty module: it only exists to carry the doctests above.
mod compile_fail_test {}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Push five links, then pop and remove them from every position (head, middle,
    /// tail, already popped), checking the list state along the way.
    #[test]
    fn linked_list() {
        let wq = DoublyLinkedList::<u32>::new();
        assert!(wq.is_empty());

        let i1 = Link::new(10);
        let i2 = Link::new(11);
        let i3 = Link::new(12);
        let i4 = Link::new(13);
        let i5 = Link::new(14);

        // SAFETY: the links are never moved and every one of them is popped or removed
        // before the end of the test.
        unsafe { wq.push(Pin::new_unchecked(&i1)) };
        unsafe { wq.push(Pin::new_unchecked(&i2)) };
        unsafe { wq.push(Pin::new_unchecked(&i3)) };
        unsafe { wq.push(Pin::new_unchecked(&i4)) };
        unsafe { wq.push(Pin::new_unchecked(&i5)) };

        assert!(!wq.is_empty());
        wq.print();

        // Popping returns the oldest element and marks its link as popped.
        assert_eq!(wq.pop(), Some(10));
        assert!(i1.is_popped());
        i1.print();

        wq.print();

        // Remove from the middle of the list.
        i4.remove_from_list(&wq);
        assert!(i4.is_popped());

        wq.print();

        // Removing an already popped link is a no-op.
        i1.remove_from_list(&wq);
        wq.print();

        println!("i2");
        i2.remove_from_list(&wq);
        wq.print();

        println!("i3");
        i3.remove_from_list(&wq);
        wq.print();

        println!("i5");
        i5.remove_from_list(&wq);
        wq.print();

        assert!(wq.is_empty());
        assert_eq!(wq.pop(), None);
        for link in [&i1, &i2, &i3, &i4, &i5] {
            assert!(link.is_popped());
        }
    }

    /// Elements come out in the order they were pushed.
    #[test]
    fn pop_is_fifo() {
        let wq = DoublyLinkedList::<u32>::new();
        let a = Link::new(1);
        let b = Link::new(2);
        let c = Link::new(3);

        // SAFETY: the links are never moved and are all popped before the end of the test.
        unsafe {
            wq.push(Pin::new_unchecked(&a));
            wq.push(Pin::new_unchecked(&b));
            wq.push(Pin::new_unchecked(&c));
        }

        assert_eq!(wq.pop(), Some(1));
        assert_eq!(wq.pop(), Some(2));
        assert_eq!(wq.pop(), Some(3));
        assert_eq!(wq.pop(), None);
        assert!(wq.is_empty());
    }

    /// Removing the head, a middle link and the tail keeps the remaining links correctly
    /// connected, as observed through `pop`.
    #[test]
    fn remove_keeps_remaining_order() {
        let wq = DoublyLinkedList::<u32>::new();
        let a = Link::new(1);
        let b = Link::new(2);
        let c = Link::new(3);
        let d = Link::new(4);
        let e = Link::new(5);

        // SAFETY: the links are never moved and are all popped or removed before the end
        // of the test.
        unsafe {
            wq.push(Pin::new_unchecked(&a));
            wq.push(Pin::new_unchecked(&b));
            wq.push(Pin::new_unchecked(&c));
            wq.push(Pin::new_unchecked(&d));
            wq.push(Pin::new_unchecked(&e));
        }

        c.remove_from_list(&wq); // middle
        e.remove_from_list(&wq); // tail
        a.remove_from_list(&wq); // head

        assert_eq!(wq.pop(), Some(2));
        assert_eq!(wq.pop(), Some(4));
        assert_eq!(wq.pop(), None);
        assert!(wq.is_empty());
    }
}
356}