mutex.rs 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. // SPDX-License-Identifier: Apache-2.0 OR MIT
  2. #![allow(clippy::undocumented_unsafe_blocks)]
  3. #![cfg_attr(feature = "alloc", feature(allocator_api))]
  4. #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
  5. #![allow(clippy::missing_safety_doc)]
  6. use core::{
  7. cell::{Cell, UnsafeCell},
  8. marker::PhantomPinned,
  9. ops::{Deref, DerefMut},
  10. pin::Pin,
  11. sync::atomic::{AtomicBool, Ordering},
  12. };
  13. #[cfg(feature = "std")]
  14. use std::{
  15. sync::Arc,
  16. thread::{self, sleep, Builder, Thread},
  17. time::Duration,
  18. };
  19. use pin_init::*;
  20. #[allow(unused_attributes)]
  21. #[path = "./linked_list.rs"]
  22. pub mod linked_list;
  23. use linked_list::*;
  24. pub struct SpinLock {
  25. inner: AtomicBool,
  26. }
  27. impl SpinLock {
  28. #[inline]
  29. pub fn acquire(&self) -> SpinLockGuard<'_> {
  30. while self
  31. .inner
  32. .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
  33. .is_err()
  34. {
  35. #[cfg(feature = "std")]
  36. while self.inner.load(Ordering::Relaxed) {
  37. thread::yield_now();
  38. }
  39. }
  40. SpinLockGuard(self)
  41. }
  42. #[inline]
  43. #[allow(clippy::new_without_default)]
  44. pub const fn new() -> Self {
  45. Self {
  46. inner: AtomicBool::new(false),
  47. }
  48. }
  49. }
  50. pub struct SpinLockGuard<'a>(&'a SpinLock);
  51. impl Drop for SpinLockGuard<'_> {
  52. #[inline]
  53. fn drop(&mut self) {
  54. self.0.inner.store(false, Ordering::Release);
  55. }
  56. }
  57. #[pin_data]
  58. pub struct CMutex<T> {
  59. #[pin]
  60. wait_list: ListHead,
  61. spin_lock: SpinLock,
  62. locked: Cell<bool>,
  63. #[pin]
  64. data: UnsafeCell<T>,
  65. }
  66. impl<T> CMutex<T> {
  67. #[inline]
  68. pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
  69. pin_init!(CMutex {
  70. wait_list <- ListHead::new(),
  71. spin_lock: SpinLock::new(),
  72. locked: Cell::new(false),
  73. data <- unsafe {
  74. pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
  75. val.__pinned_init(slot.cast::<T>())
  76. })
  77. },
  78. })
  79. }
  80. #[inline]
  81. pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
  82. let mut sguard = self.spin_lock.acquire();
  83. if self.locked.get() {
  84. stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
  85. // println!("wait list length: {}", self.wait_list.size());
  86. while self.locked.get() {
  87. drop(sguard);
  88. #[cfg(feature = "std")]
  89. thread::park();
  90. sguard = self.spin_lock.acquire();
  91. }
  92. // This does have an effect, as the ListHead inside wait_entry implements Drop!
  93. #[expect(clippy::drop_non_drop)]
  94. drop(wait_entry);
  95. }
  96. self.locked.set(true);
  97. unsafe {
  98. Pin::new_unchecked(CMutexGuard {
  99. mtx: self,
  100. _pin: PhantomPinned,
  101. })
  102. }
  103. }
  104. #[allow(dead_code)]
  105. pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
  106. // SAFETY: we have an exclusive reference and thus nobody has access to data.
  107. unsafe { &mut *self.data.get() }
  108. }
  109. }
  110. unsafe impl<T: Send> Send for CMutex<T> {}
  111. unsafe impl<T: Send> Sync for CMutex<T> {}
  112. pub struct CMutexGuard<'a, T> {
  113. mtx: &'a CMutex<T>,
  114. _pin: PhantomPinned,
  115. }
  116. impl<T> Drop for CMutexGuard<'_, T> {
  117. #[inline]
  118. fn drop(&mut self) {
  119. let sguard = self.mtx.spin_lock.acquire();
  120. self.mtx.locked.set(false);
  121. if let Some(list_field) = self.mtx.wait_list.next() {
  122. let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
  123. #[cfg(feature = "std")]
  124. unsafe {
  125. (*_wait_entry).thread.unpark()
  126. };
  127. }
  128. drop(sguard);
  129. }
  130. }
  131. impl<T> Deref for CMutexGuard<'_, T> {
  132. type Target = T;
  133. #[inline]
  134. fn deref(&self) -> &Self::Target {
  135. unsafe { &*self.mtx.data.get() }
  136. }
  137. }
  138. impl<T> DerefMut for CMutexGuard<'_, T> {
  139. #[inline]
  140. fn deref_mut(&mut self) -> &mut Self::Target {
  141. unsafe { &mut *self.mtx.data.get() }
  142. }
  143. }
  144. #[pin_data]
  145. #[repr(C)]
  146. struct WaitEntry {
  147. #[pin]
  148. wait_list: ListHead,
  149. #[cfg(feature = "std")]
  150. thread: Thread,
  151. }
  152. impl WaitEntry {
  153. #[inline]
  154. fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
  155. #[cfg(feature = "std")]
  156. {
  157. pin_init!(Self {
  158. thread: thread::current(),
  159. wait_list <- ListHead::insert_prev(list),
  160. })
  161. }
  162. #[cfg(not(feature = "std"))]
  163. {
  164. pin_init!(Self {
  165. wait_list <- ListHead::insert_prev(list),
  166. })
  167. }
  168. }
  169. }
  170. #[cfg_attr(test, test)]
  171. #[allow(dead_code)]
  172. fn main() {
  173. #[cfg(feature = "std")]
  174. {
  175. let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
  176. let mut handles = vec![];
  177. let thread_count = 20;
  178. let workload = if cfg!(miri) { 100 } else { 1_000 };
  179. for i in 0..thread_count {
  180. let mtx = mtx.clone();
  181. handles.push(
  182. Builder::new()
  183. .name(format!("worker #{i}"))
  184. .spawn(move || {
  185. for _ in 0..workload {
  186. *mtx.lock() += 1;
  187. }
  188. println!("{i} halfway");
  189. sleep(Duration::from_millis((i as u64) * 10));
  190. for _ in 0..workload {
  191. *mtx.lock() += 1;
  192. }
  193. println!("{i} finished");
  194. })
  195. .expect("should not fail"),
  196. );
  197. }
  198. for h in handles {
  199. h.join().expect("thread panicked");
  200. }
  201. println!("{:?}", &*mtx.lock());
  202. assert_eq!(*mtx.lock(), workload * thread_count * 2);
  203. }
  204. }