rtic_common/
wait_queue.rs1use core::{
4    future::poll_fn,
5    marker::PhantomPinned,
6    pin::{pin, Pin},
7    ptr::null_mut,
8    task::{Poll, Waker},
9};
10use critical_section as cs;
11use portable_atomic::{AtomicBool, AtomicPtr, Ordering};
12
13use crate::dropper::OnDropWith;
14
/// A queue of [`Waker`]s, implemented as a [`DoublyLinkedList`] whose elements are wakers.
pub type WaitQueue = DoublyLinkedList<Waker>;
17
18pub struct DoublyLinkedList<T> {
23    head: AtomicPtr<Link<T>>, tail: AtomicPtr<Link<T>>,
25}
26
27impl<T> DoublyLinkedList<T> {
28    pub const fn new() -> Self {
30        Self {
31            head: AtomicPtr::new(null_mut()),
32            tail: AtomicPtr::new(null_mut()),
33        }
34    }
35}
36
37impl<T> Default for DoublyLinkedList<T> {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl<T: Clone> DoublyLinkedList<T> {
44    const R: Ordering = Ordering::Relaxed;
45
46    pub fn pop(&self) -> Option<T> {
48        cs::with(|_| {
49            core::sync::atomic::fence(Ordering::SeqCst);
51
52            let head = self.head.load(Self::R);
53
54            if let Some(head_ref) = unsafe { head.as_ref() } {
56                self.head.store(head_ref.next.load(Self::R), Self::R);
58
59                let head_val = head_ref.val.clone();
61
62                let tail = self.tail.load(Self::R);
63                if head == tail {
64                    self.tail.store(null_mut(), Self::R);
66                }
67
68                if let Some(next_ref) = unsafe { head_ref.next.load(Self::R).as_ref() } {
69                    next_ref.prev.store(null_mut(), Self::R);
70                }
71
72                head_ref.next.store(null_mut(), Self::R);
74                head_ref.prev.store(null_mut(), Self::R);
75                head_ref.is_popped.store(true, Self::R);
76
77                return Some(head_val);
78            }
79
80            None
81        })
82    }
83
84    pub unsafe fn push(&self, link: Pin<&Link<T>>) {
90        cs::with(|_| {
91            core::sync::atomic::fence(Ordering::SeqCst);
93
94            let tail = self.tail.load(Self::R);
95
96            let link = link.get_ref();
98
99            if let Some(tail_ref) = unsafe { tail.as_ref() } {
100                link.prev.store(tail, Self::R);
102                self.tail.store(link as *const _ as *mut _, Self::R);
103                tail_ref.next.store(link as *const _ as *mut _, Self::R);
104            } else {
105                self.tail.store(link as *const _ as *mut _, Self::R);
107                self.head.store(link as *const _ as *mut _, Self::R);
108            }
109        });
110    }
111
112    pub fn is_empty(&self) -> bool {
114        self.head.load(Self::R).is_null()
115    }
116}
117
/// A node of a [`DoublyLinkedList`]. It is owned by the caller and handed to the list as a
/// `Pin<&Link<T>>`, so it cannot move while the list points at it.
pub struct Link<T> {
    /// The payload carried by this node.
    pub(crate) val: T,
    /// Neighbour towards the tail, or null.
    next: AtomicPtr<Link<T>>,
    /// Neighbour towards the head, or null.
    prev: AtomicPtr<Link<T>>,
    /// Set once the node has been popped from, or removed from, its list.
    is_popped: AtomicBool,
    /// Makes the type `!Unpin`, so pinning actually guarantees a stable address.
    _up: PhantomPinned,
}
126
127impl<T: Clone> Link<T> {
128    const R: Ordering = Ordering::Relaxed;
129
130    pub const fn new(val: T) -> Self {
132        Self {
133            val,
134            next: AtomicPtr::new(null_mut()),
135            prev: AtomicPtr::new(null_mut()),
136            is_popped: AtomicBool::new(false),
137            _up: PhantomPinned,
138        }
139    }
140
141    pub fn is_popped(&self) -> bool {
143        self.is_popped.load(Self::R)
144    }
145
146    pub fn remove_from_list(&self, list: &DoublyLinkedList<T>) {
148        cs::with(|_| {
149            core::sync::atomic::fence(Ordering::SeqCst);
151
152            if self.is_popped() {
153                return;
154            }
155
156            let prev = self.prev.load(Self::R);
157            let next = self.next.load(Self::R);
158            self.is_popped.store(true, Self::R);
159
160            match unsafe { (prev.as_ref(), next.as_ref()) } {
161                (None, None) => {
162                    let sp = self as *const _;
164
165                    if sp == list.head.load(Ordering::Relaxed) {
166                        list.head.store(null_mut(), Self::R);
167                        list.tail.store(null_mut(), Self::R);
168                    }
169                }
170                (None, Some(next_ref)) => {
171                    next_ref.prev.store(null_mut(), Self::R);
173                    list.head.store(next, Self::R);
174                }
175                (Some(prev_ref), None) => {
176                    prev_ref.next.store(null_mut(), Self::R);
178                    list.tail.store(prev, Self::R);
179                }
180                (Some(prev_ref), Some(next_ref)) => {
181                    prev_ref.next.store(next, Self::R);
185                    next_ref.prev.store(prev, Self::R);
186                }
187            }
188        })
189    }
190}
191
#[cfg(test)]
impl<T: core::fmt::Debug + Clone> DoublyLinkedList<T> {
    /// Test-only debug dump: prints the head/tail addresses, then every node from head to tail
    /// with its value, own address and neighbour addresses.
    fn print(&self) {
        cs::with(|_| {
            core::sync::atomic::fence(Ordering::SeqCst);

            let first = self.head.load(Ordering::Relaxed);
            let last = self.tail.load(Ordering::Relaxed);
            println!(
                "List - h = 0x{:x}, t = 0x{:x}",
                first as usize, last as usize
            );

            let mut cursor = first;
            let mut index = 0;

            // SAFETY: nodes reachable from the head are live links (see `push`).
            while let Some(node) = unsafe { cursor.as_ref() } {
                let next = node.next.load(Ordering::Relaxed);
                let prev = node.prev.load(Ordering::Relaxed);

                println!(
                    "    {}: {:?}, s = 0x{:x}, n = 0x{:x}, p = 0x{:x}",
                    index, node.val, cursor as usize, next as usize, prev as usize
                );

                cursor = next;
                index += 1;
            }
        });
    }
}
227
impl DoublyLinkedList<Waker> {
    /// Wait until `f` returns `Some`, then return the value.
    ///
    /// On every poll:
    /// 1. the link registered by the previous poll (if any) is removed from this queue;
    /// 2. `f` is called, and if it returns `Some(val)` the future completes with `val`;
    /// 3. otherwise a new link holding a clone of the current waker is pushed onto this queue
    ///    and `Poll::Pending` is returned.
    pub async fn wait_until<T, F: FnMut() -> Option<T>>(&self, mut f: F) -> T {
        // Pinned slot for this future's link; the queue stores the link's address, so it must
        // not move while it is enqueued.
        let link_place = pin!(None::<Link<Waker>>);

        // Cleanup action: unlink the stored link (if any) from this queue and clear the slot.
        // NOTE(review): presumably `OnDropWith` also runs this closure when the future is
        // dropped, which is what keeps a cancelled wait from leaving a dangling link — confirm
        // against `crate::dropper::OnDropWith`.
        let mut link_guard = OnDropWith::new(link_place, |link| {
            if let Some(link) = link.as_ref().as_pin_ref() {
                link.remove_from_list(self);
            }
            link.set(None);
        });

        poll_fn(move |cx| {
            // Clean up the link from the previous poll, if any.
            link_guard.execute();

            // Check if we're ready.
            // NOTE(review): `f()` runs before the waker is pushed and outside `push`'s critical
            // section; confirm callers cannot miss a wakeup if their state changes in between.
            if let Some(val) = f() {
                return Poll::Ready(val);
            }

            // Create a link for the current waker.
            let new_link = Link::new(cx.waker().clone());

            // Store it in the pinned slot so its address stays stable.
            link_guard.set(Some(new_link));

            let new_link_pinned = link_guard.as_ref().as_pin_ref().expect("We just set it");

            // SAFETY: the link is pinned in `link_place`, and the guard's cleanup removes it from
            // the queue before the slot is overwritten.
            unsafe { self.push(new_link_pinned) };

            Poll::Pending
        })
        .await
    }
}
272
#[cfg(test)]
impl<T: core::fmt::Debug + Clone> Link<T> {
    /// Test-only debug dump: prints this link's value and the raw addresses of its neighbours.
    fn print(&self) {
        cs::with(|_| {
            core::sync::atomic::fence(Ordering::SeqCst);

            let next_addr = self.next.load(Ordering::Relaxed) as usize;
            let prev_addr = self.prev.load(Ordering::Relaxed) as usize;

            println!("Link:");
            println!(
                "    val = {:?}, n = 0x{:x}, p = 0x{:x}",
                self.val, next_addr, prev_addr
            );
        });
    }
}
291
// NOTE(review): this module is empty. Judging by its name it was presumably meant to host
// `compile_fail` doc tests; confirm, and either restore them or remove the module.
mod compile_fail_test {}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises push, pop and removal from the middle, head and last position of the list,
    /// asserting the resulting state instead of only printing it.
    #[test]
    fn linked_list() {
        let wq = DoublyLinkedList::<u32>::new();
        assert!(wq.is_empty());

        let i1 = Link::new(10);
        let i2 = Link::new(11);
        let i3 = Link::new(12);
        let i4 = Link::new(13);
        let i5 = Link::new(14);

        unsafe { wq.push(Pin::new_unchecked(&i1)) };
        unsafe { wq.push(Pin::new_unchecked(&i2)) };
        unsafe { wq.push(Pin::new_unchecked(&i3)) };
        unsafe { wq.push(Pin::new_unchecked(&i4)) };
        unsafe { wq.push(Pin::new_unchecked(&i5)) };
        assert!(!wq.is_empty());

        wq.print();

        // Popping yields the first pushed value and marks that link as popped.
        assert_eq!(wq.pop(), Some(10));
        assert!(i1.is_popped());
        i1.print();

        wq.print();

        // Removal from the middle of the list.
        i4.remove_from_list(&wq);
        assert!(i4.is_popped());

        wq.print();

        // i2 is now the head.
        println!("i2");
        i2.remove_from_list(&wq);
        assert!(i2.is_popped());
        wq.print();

        println!("i3");
        i3.remove_from_list(&wq);
        assert!(i3.is_popped());
        wq.print();

        // i5 is the last remaining link; removing it empties the list.
        println!("i5");
        i5.remove_from_list(&wq);
        assert!(i5.is_popped());
        wq.print();

        assert!(wq.is_empty());
        assert_eq!(wq.pop(), None);
    }
}