| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- // SPDX-License-Identifier: Apache-2.0 OR MIT
- #![allow(clippy::undocumented_unsafe_blocks)]
- #![cfg_attr(feature = "alloc", feature(allocator_api))]
- #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
- #![allow(clippy::missing_safety_doc)]
- use core::{
- cell::{Cell, UnsafeCell},
- marker::PhantomPinned,
- ops::{Deref, DerefMut},
- pin::Pin,
- sync::atomic::{AtomicBool, Ordering},
- };
- #[cfg(feature = "std")]
- use std::{
- sync::Arc,
- thread::{self, sleep, Builder, Thread},
- time::Duration,
- };
- use pin_init::*;
- #[allow(unused_attributes)]
- #[path = "./linked_list.rs"]
- pub mod linked_list;
- use linked_list::*;
/// A minimal test-and-set spin lock built on a single [`AtomicBool`].
///
/// `false` means unlocked and `true` means locked. The lock is released when the
/// [`SpinLockGuard`] returned by [`SpinLock::acquire`] is dropped.
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    /// Blocks the caller until the lock has been acquired and returns a guard that
    /// releases it on drop.
    ///
    /// While the lock is contended the caller waits on a plain load instead of retrying
    /// the read-modify-write operation, yielding to the scheduler when the `std` feature
    /// is enabled and emitting a CPU spin hint otherwise.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // `compare_exchange_weak` may fail spuriously, which is harmless here: the wait
        // loop below exits immediately when the flag is clear and the exchange is retried.
        while self
            .inner
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Only attempt the exchange again once the flag has been observed as clear.
            while self.inner.load(Ordering::Relaxed) {
                #[cfg(feature = "std")]
                thread::yield_now();
                #[cfg(not(feature = "std"))]
                core::hint::spin_loop();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// Proof that the referenced [`SpinLock`] is held; dropping it releases the lock.
#[must_use = "the spin lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    /// Clears the lock flag with `Release` ordering, pairing with the `Acquire` exchange
    /// in [`SpinLock::acquire`].
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}
- #[pin_data]
- pub struct CMutex<T> {
- #[pin]
- wait_list: ListHead,
- spin_lock: SpinLock,
- locked: Cell<bool>,
- #[pin]
- data: UnsafeCell<T>,
- }
- impl<T> CMutex<T> {
- #[inline]
- pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
- pin_init!(CMutex {
- wait_list <- ListHead::new(),
- spin_lock: SpinLock::new(),
- locked: Cell::new(false),
- data <- unsafe {
- pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
- val.__pinned_init(slot.cast::<T>())
- })
- },
- })
- }
- #[inline]
- pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
- let mut sguard = self.spin_lock.acquire();
- if self.locked.get() {
- stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
- // println!("wait list length: {}", self.wait_list.size());
- while self.locked.get() {
- drop(sguard);
- #[cfg(feature = "std")]
- thread::park();
- sguard = self.spin_lock.acquire();
- }
- // This does have an effect, as the ListHead inside wait_entry implements Drop!
- #[expect(clippy::drop_non_drop)]
- drop(wait_entry);
- }
- self.locked.set(true);
- unsafe {
- Pin::new_unchecked(CMutexGuard {
- mtx: self,
- _pin: PhantomPinned,
- })
- }
- }
- #[allow(dead_code)]
- pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
- // SAFETY: we have an exclusive reference and thus nobody has access to data.
- unsafe { &mut *self.data.get() }
- }
- }
- unsafe impl<T: Send> Send for CMutex<T> {}
- unsafe impl<T: Send> Sync for CMutex<T> {}
- pub struct CMutexGuard<'a, T> {
- mtx: &'a CMutex<T>,
- _pin: PhantomPinned,
- }
- impl<T> Drop for CMutexGuard<'_, T> {
- #[inline]
- fn drop(&mut self) {
- let sguard = self.mtx.spin_lock.acquire();
- self.mtx.locked.set(false);
- if let Some(list_field) = self.mtx.wait_list.next() {
- let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
- #[cfg(feature = "std")]
- unsafe {
- (*_wait_entry).thread.unpark()
- };
- }
- drop(sguard);
- }
- }
- impl<T> Deref for CMutexGuard<'_, T> {
- type Target = T;
- #[inline]
- fn deref(&self) -> &Self::Target {
- unsafe { &*self.mtx.data.get() }
- }
- }
- impl<T> DerefMut for CMutexGuard<'_, T> {
- #[inline]
- fn deref_mut(&mut self) -> &mut Self::Target {
- unsafe { &mut *self.mtx.data.get() }
- }
- }
- #[pin_data]
- #[repr(C)]
- struct WaitEntry {
- #[pin]
- wait_list: ListHead,
- #[cfg(feature = "std")]
- thread: Thread,
- }
- impl WaitEntry {
- #[inline]
- fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
- #[cfg(feature = "std")]
- {
- pin_init!(Self {
- thread: thread::current(),
- wait_list <- ListHead::insert_prev(list),
- })
- }
- #[cfg(not(feature = "std"))]
- {
- pin_init!(Self {
- wait_list <- ListHead::insert_prev(list),
- })
- }
- }
- }
/// Stress test and example entry point for [`CMutex`].
///
/// With the `std` feature, spawns 20 named worker threads. Each worker increments the
/// shared counter `workload` times, sleeps for `10 * i` milliseconds, then increments it
/// another `workload` times. After joining all workers the final count is printed and
/// checked. Without the `std` feature the body is empty.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let thread_count = 20;
        // Miri is slow, so it gets a smaller workload.
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();

        // Spawn every worker before joining any of them.
        let workers: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let add_batch = || {
                            for _ in 0..workload {
                                *mtx.lock() += 1;
                            }
                        };
                        add_batch();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        add_batch();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        workers.into_iter().for_each(|worker| {
            worker.join().expect("thread panicked");
        });

        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
|