static_init.rs 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. // SPDX-License-Identifier: Apache-2.0 OR MIT
  2. #![allow(clippy::undocumented_unsafe_blocks)]
  3. #![cfg_attr(feature = "alloc", feature(allocator_api))]
  4. #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
  5. #![allow(unused_imports)]
  6. use core::{
  7. cell::{Cell, UnsafeCell},
  8. mem::MaybeUninit,
  9. ops,
  10. pin::Pin,
  11. time::Duration,
  12. };
  13. use pin_init::*;
  14. #[cfg(feature = "std")]
  15. use std::{
  16. sync::Arc,
  17. thread::{sleep, Builder},
  18. };
  19. #[allow(unused_attributes)]
  20. mod mutex;
  21. use mutex::*;
  22. pub struct StaticInit<T, I> {
  23. cell: UnsafeCell<MaybeUninit<T>>,
  24. init: Cell<Option<I>>,
  25. lock: SpinLock,
  26. present: Cell<bool>,
  27. }
  28. unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
  29. unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
  30. impl<T, I: PinInit<T>> StaticInit<T, I> {
  31. pub const fn new(init: I) -> Self {
  32. Self {
  33. cell: UnsafeCell::new(MaybeUninit::uninit()),
  34. init: Cell::new(Some(init)),
  35. lock: SpinLock::new(),
  36. present: Cell::new(false),
  37. }
  38. }
  39. }
  40. impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
  41. type Target = T;
  42. fn deref(&self) -> &Self::Target {
  43. if self.present.get() {
  44. unsafe { (*self.cell.get()).assume_init_ref() }
  45. } else {
  46. println!("acquire spinlock on static init");
  47. let _guard = self.lock.acquire();
  48. println!("rechecking present...");
  49. std::thread::sleep(std::time::Duration::from_millis(200));
  50. if self.present.get() {
  51. return unsafe { (*self.cell.get()).assume_init_ref() };
  52. }
  53. println!("doing init");
  54. let ptr = self.cell.get().cast::<T>();
  55. match self.init.take() {
  56. Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
  57. None => unsafe { core::hint::unreachable_unchecked() },
  58. }
  59. self.present.set(true);
  60. unsafe { (*self.cell.get()).assume_init_ref() }
  61. }
  62. }
  63. }
  64. pub struct CountInit;
  65. unsafe impl PinInit<CMutex<usize>> for CountInit {
  66. unsafe fn __pinned_init(
  67. self,
  68. slot: *mut CMutex<usize>,
  69. ) -> Result<(), core::convert::Infallible> {
  70. let init = CMutex::new(0);
  71. std::thread::sleep(std::time::Duration::from_millis(1000));
  72. unsafe { init.__pinned_init(slot) }
  73. }
  74. }
  75. pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
  76. fn main() {
  77. #[cfg(feature = "std")]
  78. {
  79. let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
  80. let mut handles = vec![];
  81. let thread_count = 20;
  82. let workload = 1_000;
  83. for i in 0..thread_count {
  84. let mtx = mtx.clone();
  85. handles.push(
  86. Builder::new()
  87. .name(format!("worker #{i}"))
  88. .spawn(move || {
  89. for _ in 0..workload {
  90. *COUNT.lock() += 1;
  91. std::thread::sleep(std::time::Duration::from_millis(10));
  92. *mtx.lock() += 1;
  93. std::thread::sleep(std::time::Duration::from_millis(10));
  94. *COUNT.lock() += 1;
  95. }
  96. println!("{i} halfway");
  97. sleep(Duration::from_millis((i as u64) * 10));
  98. for _ in 0..workload {
  99. std::thread::sleep(std::time::Duration::from_millis(10));
  100. *mtx.lock() += 1;
  101. }
  102. println!("{i} finished");
  103. })
  104. .expect("should not fail"),
  105. );
  106. }
  107. for h in handles {
  108. h.join().expect("thread panicked");
  109. }
  110. println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
  111. assert_eq!(*mtx.lock(), workload * thread_count * 2);
  112. }
  113. }