Skip to main content

polars_core/
runtime.rs

1use std::cell::RefCell;
2use std::panic::{AssertUnwindSafe, catch_unwind};
3use std::sync::LazyLock;
4use std::sync::mpsc::{TryRecvError, sync_channel};
5
6use polars_utils::with_drop::WithDrop;
7use rayon::{ThreadPool, ThreadPoolBuilder, Yield};
8use tokio::runtime::{Builder, Runtime};
9
/// Zero-sized handle through which rayon work is submitted.
///
/// Its methods mirror the rayon pool API (`install`, `join`, `scope`, `spawn`, ...). Where a
/// pool is available they go through [`RAYON::with`], which selects the pool to use on every
/// call.
pub struct RAYON;
11
12// Thread locals to allow disabling threading for specific threads.
13#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
14thread_local! {
15    static NOOP_POOL: RefCell<ThreadPool> = RefCell::new(
16        ThreadPoolBuilder::new()
17            .use_current_thread()
18            .num_threads(1)
19            .build()
20            .expect("could not create no-op thread pool")
21    );
22}
23
24impl RAYON {
25    pub fn install<OP, R>(&self, op: OP) -> R
26    where
27        OP: FnOnce() -> R + Send,
28        R: Send,
29    {
30        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
31        {
32            op()
33        }
34
35        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
36        {
37            self.with(|p| p.install(op))
38        }
39    }
40
41    pub fn join<A, B, RA, RB>(&self, oper_a: A, oper_b: B) -> (RA, RB)
42    where
43        A: FnOnce() -> RA + Send,
44        B: FnOnce() -> RB + Send,
45        RA: Send,
46        RB: Send,
47    {
48        self.install(|| rayon::join(oper_a, oper_b))
49    }
50
51    pub fn scope<'scope, OP, R>(&self, op: OP) -> R
52    where
53        OP: FnOnce(&rayon::Scope<'scope>) -> R + Send,
54        R: Send,
55    {
56        self.install(|| rayon::scope(op))
57    }
58
59    pub fn spawn<OP>(&self, op: OP)
60    where
61        OP: FnOnce() + Send + 'static,
62    {
63        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
64        {
65            rayon::spawn(op)
66        }
67
68        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
69        {
70            self.with(|p| {
71                p.spawn(op);
72                if p.current_num_threads() == 1 {
73                    p.yield_now();
74                }
75            })
76        }
77    }
78
79    pub fn spawn_fifo<OP>(&self, op: OP)
80    where
81        OP: FnOnce() + Send + 'static,
82    {
83        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
84        {
85            rayon::spawn_fifo(op)
86        }
87
88        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
89        {
90            self.with(|p| {
91                p.spawn_fifo(op);
92                if p.current_num_threads() == 1 {
93                    p.yield_now();
94                }
95            })
96        }
97    }
98
99    pub fn current_thread_has_pending_tasks(&self) -> Option<bool> {
100        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
101        {
102            None
103        }
104
105        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
106        {
107            self.with(|p| p.current_thread_has_pending_tasks())
108        }
109    }
110
111    pub fn current_thread_index(&self) -> Option<usize> {
112        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
113        {
114            rayon::current_thread_index()
115        }
116
117        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
118        {
119            self.with(|p| p.current_thread_index())
120        }
121    }
122
123    pub fn current_num_threads(&self) -> usize {
124        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
125        {
126            rayon::current_num_threads()
127        }
128
129        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
130        {
131            self.with(|p| p.current_num_threads())
132        }
133    }
134
135    #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
136    pub fn with<OP, R>(&self, op: OP) -> R
137    where
138        OP: FnOnce(&ThreadPool) -> R + Send,
139        R: Send,
140    {
141        if polars_async::executor::ALLOW_RAYON_THREADS.get()
142            || THREAD_POOL.current_thread_index().is_some()
143        {
144            op(&THREAD_POOL)
145        } else {
146            NOOP_POOL.with(|v| op(&v.borrow()))
147        }
148    }
149
    /// Calls a blocking function without blocking the rayon thread pool by
    /// moving it to a different thread.
    ///
    /// When called on a `THREAD_POOL` worker, `f` is handed to `ASYNC.spawn_blocking` and
    /// this thread keeps running other rayon work (or yields to the OS) until the result
    /// arrives. A panic inside `f` is caught on the other thread and re-raised here.
    ///
    /// If this thread isn't a rayon thread this simply calls f directly.
    pub fn block_on<R: Send, F: FnOnce() -> R + Send>(&self, f: F) -> R {
        if THREAD_POOL.current_thread_index().is_some() {
            // Channel through which the blocking thread hands back the result (or panic).
            let (send, recv) = sync_channel(1);
            // `f` is `FnOnce`; parking it in an `Option` lets the `FnMut` wrapper below
            // move it out exactly once.
            let mut opt_f: Option<F> = Some(f);
            let mut wrap_f = || {
                let f = AssertUnwindSafe(opt_f.take().unwrap());
                // Catch a panic in `f` so it can be forwarded and re-raised by the caller.
                send.send(catch_unwind(f)).unwrap();
            };

            // SAFETY: we always await the future to completion before returning from here, meaning
            // wrap_f stays alive for as long as it needs to. If for whatever reason we unwind we
            // abort.
            // NOTE(review): the closure may still be returning from `send.send` when the value
            // is received below and `send`/`recv` are dropped; confirm this window is benign.
            let abort = WithDrop::new((), |()| std::process::abort());
            let ref_wrap_f: &mut (dyn Send + FnMut()) = &mut wrap_f;
            // Erase the borrow's lifetime so it satisfies `spawn_blocking`'s `'static` bound.
            let static_wrap_f: &'static mut (dyn Send + FnMut() + 'static) =
                unsafe { core::mem::transmute(ref_wrap_f) };
            ASYNC.spawn_blocking(static_wrap_f);

            // Poll for the result, doing useful rayon work in between polls.
            loop {
                match recv.try_recv() {
                    Ok(r) => {
                        // Result received: disarm the abort guard before returning/unwinding.
                        WithDrop::dismiss(abort);
                        match r {
                            Ok(v) => return v,
                            Err(panic) => std::panic::resume_unwind(panic),
                        }
                    },
                    Err(TryRecvError::Empty) => match rayon::yield_now() {
                        // Ran another rayon task; poll again right away.
                        Some(Yield::Executed) => {},
                        // Nothing to run in the pool; let the OS schedule something else.
                        Some(Yield::Idle) => std::thread::yield_now(),
                        // Only reached on a rayon worker (checked above), so `None` is not expected.
                        None => unreachable!(),
                    },
                    // `send` is only borrowed by `wrap_f` and is still alive in this scope.
                    Err(TryRecvError::Disconnected) => unreachable!(),
                }
            }
        } else {
            f()
        }
    }
193}
194
195// this is re-exported in utils for polars child crates
196#[cfg(not(target_family = "wasm"))] // only use this on non wasm targets
197pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
198    let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string());
199    ThreadPoolBuilder::new()
200        .num_threads(
201            std::env::var("POLARS_MAX_THREADS")
202                .map(|s| s.parse::<usize>().expect("integer"))
203                .unwrap_or_else(|_| {
204                    std::thread::available_parallelism()
205                        .unwrap_or(std::num::NonZeroUsize::new(1).unwrap())
206                        .get()
207                }),
208        )
209        .thread_name(move |i| format!("{thread_name}-{i}"))
210        .build()
211        .expect("could not spawn threads")
212});
213
/// Shared rayon pool on emscripten: a single thread, built with `use_current_thread`.
#[cfg(all(target_os = "emscripten", target_family = "wasm"))] // Use 1 rayon thread on emscripten
pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
    let single_threaded = ThreadPoolBuilder::new().use_current_thread().num_threads(1);
    single_threaded.build().expect("could not create pool")
});
222
/// Wrapper around the Tokio [`Runtime`] used for async and blocking work.
pub struct AsyncRuntime {
    /// The owned Tokio runtime that every method delegates to.
    rt: Runtime,
}
226
227impl AsyncRuntime {
228    fn new() -> Self {
229        let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
230            .map(|x| x.parse::<usize>().expect("integer"))
231            .unwrap_or(usize::min(RAYON.current_num_threads(), 32));
232
233        let max_blocking = std::env::var("POLARS_MAX_BLOCKING_THREAD_COUNT")
234            .map(|x| x.parse::<usize>().expect("integer"))
235            .unwrap_or(512);
236
237        if crate::config::verbose() {
238            eprintln!("async thread count: {n_threads}");
239            eprintln!("blocking thread count: {max_blocking}");
240        }
241
242        let rt = Builder::new_multi_thread()
243            .worker_threads(n_threads)
244            .max_blocking_threads(max_blocking)
245            .enable_io()
246            .enable_time()
247            .build()
248            .unwrap();
249
250        Self { rt }
251    }
252
253    /// Forcibly blocks this thread to evaluate the given future. This can be
254    /// dangerous and lead to deadlocks if called re-entrantly on an async
255    /// worker thread as the entire thread pool can end up blocking, leading to
256    /// a deadlock. If you want to prevent this use block_on, which will panic
257    /// if called from an async thread.
258    pub fn block_in_place_on<F>(&self, future: F) -> F::Output
259    where
260        F: Future,
261    {
262        tokio::task::block_in_place(|| self.rt.block_on(future))
263    }
264
265    /// Blocks this thread to evaluate the given future. Panics if the current
266    /// thread is an async runtime worker thread.
267    pub fn block_on<F>(&self, future: F) -> F::Output
268    where
269        F: Future,
270    {
271        self.rt.block_on(future)
272    }
273
274    /// Spawns a future onto the Tokio runtime (see [`tokio::runtime::Runtime::spawn`]).
275    pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
276    where
277        F: Future + Send + 'static,
278        F::Output: Send + 'static,
279    {
280        self.rt.spawn(future)
281    }
282
283    // See [`tokio::runtime::Runtime::spawn_blocking`].
284    pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
285    where
286        F: FnOnce() -> R + Send + 'static,
287        R: Send + 'static,
288    {
289        self.rt.spawn_blocking(f)
290    }
291}
292
/// Global [`AsyncRuntime`], built by `AsyncRuntime::new` on first access.
pub static ASYNC: LazyLock<AsyncRuntime> = LazyLock::new(AsyncRuntime::new);