1use std::cell::RefCell;
2use std::panic::{AssertUnwindSafe, catch_unwind};
3use std::sync::LazyLock;
4use std::sync::mpsc::{TryRecvError, sync_channel};
5
6use polars_utils::with_drop::WithDrop;
7use rayon::{ThreadPool, ThreadPoolBuilder, Yield};
8
/// Zero-sized handle whose methods route Rayon work to this module's thread pools.
///
/// On targets where this module defines pools (non-wasm, or emscripten), work goes either to
/// the global `THREAD_POOL` or to the calling thread's single-worker `NOOP_POOL`. On other
/// wasm targets the methods fall back to the free `rayon` functions or direct calls.
pub struct RAYON;
10
11#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
13thread_local! {
14 static NOOP_POOL: RefCell<ThreadPool> = RefCell::new(
15 ThreadPoolBuilder::new()
16 .use_current_thread()
17 .num_threads(1)
18 .build()
19 .expect("could not create no-op thread pool")
20 );
21}
22
23impl RAYON {
24 pub fn install<OP, R>(&self, op: OP) -> R
25 where
26 OP: FnOnce() -> R + Send,
27 R: Send,
28 {
29 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
30 {
31 op()
32 }
33
34 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
35 {
36 self.with(|p| p.install(op))
37 }
38 }
39
40 pub fn join<A, B, RA, RB>(&self, oper_a: A, oper_b: B) -> (RA, RB)
41 where
42 A: FnOnce() -> RA + Send,
43 B: FnOnce() -> RB + Send,
44 RA: Send,
45 RB: Send,
46 {
47 self.install(|| rayon::join(oper_a, oper_b))
48 }
49
50 pub fn scope<'scope, OP, R>(&self, op: OP) -> R
51 where
52 OP: FnOnce(&rayon::Scope<'scope>) -> R + Send,
53 R: Send,
54 {
55 self.install(|| rayon::scope(op))
56 }
57
58 pub fn spawn<OP>(&self, op: OP)
59 where
60 OP: FnOnce() + Send + 'static,
61 {
62 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
63 {
64 rayon::spawn(op)
65 }
66
67 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
68 {
69 self.with(|p| {
70 p.spawn(op);
71 if p.current_num_threads() == 1 {
72 p.yield_now();
73 }
74 })
75 }
76 }
77
78 pub fn spawn_fifo<OP>(&self, op: OP)
79 where
80 OP: FnOnce() + Send + 'static,
81 {
82 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
83 {
84 rayon::spawn_fifo(op)
85 }
86
87 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
88 {
89 self.with(|p| {
90 p.spawn_fifo(op);
91 if p.current_num_threads() == 1 {
92 p.yield_now();
93 }
94 })
95 }
96 }
97
98 pub fn current_thread_has_pending_tasks(&self) -> Option<bool> {
99 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
100 {
101 None
102 }
103
104 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
105 {
106 self.with(|p| p.current_thread_has_pending_tasks())
107 }
108 }
109
110 pub fn current_thread_index(&self) -> Option<usize> {
111 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
112 {
113 rayon::current_thread_index()
114 }
115
116 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
117 {
118 self.with(|p| p.current_thread_index())
119 }
120 }
121
122 pub fn current_num_threads(&self) -> usize {
123 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
124 {
125 rayon::current_num_threads()
126 }
127
128 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
129 {
130 self.with(|p| p.current_num_threads())
131 }
132 }
133
134 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
135 pub fn with<OP, R>(&self, op: OP) -> R
136 where
137 OP: FnOnce(&ThreadPool) -> R + Send,
138 R: Send,
139 {
140 if polars_async::executor::ALLOW_RAYON_THREADS.get()
141 || THREAD_POOL.current_thread_index().is_some()
142 {
143 op(&THREAD_POOL)
144 } else {
145 NOOP_POOL.with(|v| op(&v.borrow()))
146 }
147 }
148
149 pub fn block_on<R: Send, F: FnOnce() -> R + Send>(&self, f: F) -> R {
154 if THREAD_POOL.current_thread_index().is_some() {
155 let (send, recv) = sync_channel(1);
156 let mut opt_f: Option<F> = Some(f);
157 let mut wrap_f = || {
158 let f = AssertUnwindSafe(opt_f.take().unwrap());
159 send.send(catch_unwind(f)).unwrap();
160 };
161
162 let abort = WithDrop::new((), |()| std::process::abort());
166 let ref_wrap_f: &mut (dyn Send + FnMut()) = &mut wrap_f;
167 let static_wrap_f: &'static mut (dyn Send + FnMut() + 'static) =
168 unsafe { core::mem::transmute(ref_wrap_f) };
169 ASYNC.spawn_blocking(static_wrap_f);
170
171 loop {
172 match recv.try_recv() {
173 Ok(r) => {
174 WithDrop::dismiss(abort);
175 match r {
176 Ok(v) => return v,
177 Err(panic) => std::panic::resume_unwind(panic),
178 }
179 },
180 Err(TryRecvError::Empty) => match rayon::yield_now() {
181 Some(Yield::Executed) => {},
182 Some(Yield::Idle) => std::thread::yield_now(),
183 None => unreachable!(),
184 },
185 Err(TryRecvError::Disconnected) => unreachable!(),
186 }
187 }
188 } else {
189 f()
190 }
191 }
192}
193
194#[cfg(not(target_family = "wasm"))] pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
197 let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string());
198 ThreadPoolBuilder::new()
199 .num_threads(polars_config::config().max_threads())
200 .thread_name(move |i| format!("{thread_name}-{i}"))
201 .build()
202 .expect("could not spawn threads")
203});
204
/// Global pool on emscripten. It has a single worker, namely the thread that first forces
/// this lazy value (`use_current_thread`).
#[cfg(all(target_os = "emscripten", target_family = "wasm"))]
pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
    let builder = ThreadPoolBuilder::new().use_current_thread().num_threads(1);
    builder.build().expect("could not create pool")
});
213
214pub use polars_async::ASYNC;