1use std::cell::RefCell;
2use std::panic::{AssertUnwindSafe, catch_unwind};
3use std::sync::LazyLock;
4use std::sync::mpsc::{TryRecvError, sync_channel};
5
6use polars_utils::with_drop::WithDrop;
7use rayon::{ThreadPool, ThreadPoolBuilder, Yield};
8use tokio::runtime::{Builder, Runtime};
9
/// Zero-sized handle whose methods (`install`, `join`, `scope`, `spawn`, ...)
/// mirror rayon's free functions but choose which thread pool to use per call;
/// see the `impl RAYON` block for how the pool is selected.
pub struct RAYON;
11
12#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
14thread_local! {
15 static NOOP_POOL: RefCell<ThreadPool> = RefCell::new(
16 ThreadPoolBuilder::new()
17 .use_current_thread()
18 .num_threads(1)
19 .build()
20 .expect("could not create no-op thread pool")
21 );
22}
23
24impl RAYON {
25 pub fn install<OP, R>(&self, op: OP) -> R
26 where
27 OP: FnOnce() -> R + Send,
28 R: Send,
29 {
30 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
31 {
32 op()
33 }
34
35 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
36 {
37 self.with(|p| p.install(op))
38 }
39 }
40
41 pub fn join<A, B, RA, RB>(&self, oper_a: A, oper_b: B) -> (RA, RB)
42 where
43 A: FnOnce() -> RA + Send,
44 B: FnOnce() -> RB + Send,
45 RA: Send,
46 RB: Send,
47 {
48 self.install(|| rayon::join(oper_a, oper_b))
49 }
50
51 pub fn scope<'scope, OP, R>(&self, op: OP) -> R
52 where
53 OP: FnOnce(&rayon::Scope<'scope>) -> R + Send,
54 R: Send,
55 {
56 self.install(|| rayon::scope(op))
57 }
58
59 pub fn spawn<OP>(&self, op: OP)
60 where
61 OP: FnOnce() + Send + 'static,
62 {
63 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
64 {
65 rayon::spawn(op)
66 }
67
68 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
69 {
70 self.with(|p| {
71 p.spawn(op);
72 if p.current_num_threads() == 1 {
73 p.yield_now();
74 }
75 })
76 }
77 }
78
79 pub fn spawn_fifo<OP>(&self, op: OP)
80 where
81 OP: FnOnce() + Send + 'static,
82 {
83 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
84 {
85 rayon::spawn_fifo(op)
86 }
87
88 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
89 {
90 self.with(|p| {
91 p.spawn_fifo(op);
92 if p.current_num_threads() == 1 {
93 p.yield_now();
94 }
95 })
96 }
97 }
98
99 pub fn current_thread_has_pending_tasks(&self) -> Option<bool> {
100 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
101 {
102 None
103 }
104
105 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
106 {
107 self.with(|p| p.current_thread_has_pending_tasks())
108 }
109 }
110
111 pub fn current_thread_index(&self) -> Option<usize> {
112 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
113 {
114 rayon::current_thread_index()
115 }
116
117 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
118 {
119 self.with(|p| p.current_thread_index())
120 }
121 }
122
123 pub fn current_num_threads(&self) -> usize {
124 #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
125 {
126 rayon::current_num_threads()
127 }
128
129 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
130 {
131 self.with(|p| p.current_num_threads())
132 }
133 }
134
135 #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
136 pub fn with<OP, R>(&self, op: OP) -> R
137 where
138 OP: FnOnce(&ThreadPool) -> R + Send,
139 R: Send,
140 {
141 if polars_async::executor::ALLOW_RAYON_THREADS.get()
142 || THREAD_POOL.current_thread_index().is_some()
143 {
144 op(&THREAD_POOL)
145 } else {
146 NOOP_POOL.with(|v| op(&v.borrow()))
147 }
148 }
149
150 pub fn block_on<R: Send, F: FnOnce() -> R + Send>(&self, f: F) -> R {
155 if THREAD_POOL.current_thread_index().is_some() {
156 let (send, recv) = sync_channel(1);
157 let mut opt_f: Option<F> = Some(f);
158 let mut wrap_f = || {
159 let f = AssertUnwindSafe(opt_f.take().unwrap());
160 send.send(catch_unwind(f)).unwrap();
161 };
162
163 let abort = WithDrop::new((), |()| std::process::abort());
167 let ref_wrap_f: &mut (dyn Send + FnMut()) = &mut wrap_f;
168 let static_wrap_f: &'static mut (dyn Send + FnMut() + 'static) =
169 unsafe { core::mem::transmute(ref_wrap_f) };
170 ASYNC.spawn_blocking(static_wrap_f);
171
172 loop {
173 match recv.try_recv() {
174 Ok(r) => {
175 WithDrop::dismiss(abort);
176 match r {
177 Ok(v) => return v,
178 Err(panic) => std::panic::resume_unwind(panic),
179 }
180 },
181 Err(TryRecvError::Empty) => match rayon::yield_now() {
182 Some(Yield::Executed) => {},
183 Some(Yield::Idle) => std::thread::yield_now(),
184 None => unreachable!(),
185 },
186 Err(TryRecvError::Disconnected) => unreachable!(),
187 }
188 }
189 } else {
190 f()
191 }
192 }
193}
194
195#[cfg(not(target_family = "wasm"))] pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
198 let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string());
199 ThreadPoolBuilder::new()
200 .num_threads(
201 std::env::var("POLARS_MAX_THREADS")
202 .map(|s| s.parse::<usize>().expect("integer"))
203 .unwrap_or_else(|_| {
204 std::thread::available_parallelism()
205 .unwrap_or(std::num::NonZeroUsize::new(1).unwrap())
206 .get()
207 }),
208 )
209 .thread_name(move |i| format!("{thread_name}-{i}"))
210 .build()
211 .expect("could not spawn threads")
212});
213
/// Global rayon pool on emscripten: a single-thread pool built with
/// `use_current_thread`, so the thread that first initializes this static is
/// registered as the pool's only worker.
#[cfg(all(target_os = "emscripten", target_family = "wasm"))]
pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
    let builder = ThreadPoolBuilder::new().use_current_thread().num_threads(1);
    builder.build().expect("could not create pool")
});
222
/// Thin wrapper around a tokio [`Runtime`]; the process-wide instance is
/// [`ASYNC`].
pub struct AsyncRuntime {
    /// The underlying tokio runtime (built in `AsyncRuntime::new`).
    rt: Runtime,
}
226
227impl AsyncRuntime {
228 fn new() -> Self {
229 let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
230 .map(|x| x.parse::<usize>().expect("integer"))
231 .unwrap_or(usize::min(RAYON.current_num_threads(), 32));
232
233 let max_blocking = std::env::var("POLARS_MAX_BLOCKING_THREAD_COUNT")
234 .map(|x| x.parse::<usize>().expect("integer"))
235 .unwrap_or(512);
236
237 if crate::config::verbose() {
238 eprintln!("async thread count: {n_threads}");
239 eprintln!("blocking thread count: {max_blocking}");
240 }
241
242 let rt = Builder::new_multi_thread()
243 .worker_threads(n_threads)
244 .max_blocking_threads(max_blocking)
245 .enable_io()
246 .enable_time()
247 .build()
248 .unwrap();
249
250 Self { rt }
251 }
252
253 pub fn block_in_place_on<F>(&self, future: F) -> F::Output
259 where
260 F: Future,
261 {
262 tokio::task::block_in_place(|| self.rt.block_on(future))
263 }
264
265 pub fn block_on<F>(&self, future: F) -> F::Output
268 where
269 F: Future,
270 {
271 self.rt.block_on(future)
272 }
273
274 pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
276 where
277 F: Future + Send + 'static,
278 F::Output: Send + 'static,
279 {
280 self.rt.spawn(future)
281 }
282
283 pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
285 where
286 F: FnOnce() -> R + Send + 'static,
287 R: Send + 'static,
288 {
289 self.rt.spawn_blocking(f)
290 }
291}
292
/// Process-wide async runtime, created on first access by
/// [`AsyncRuntime::new`] (which reads its thread counts from environment
/// variables).
pub static ASYNC: LazyLock<AsyncRuntime> = LazyLock::new(AsyncRuntime::new);