Skip to main content

polars_core/
runtime.rs

1use std::cell::RefCell;
2use std::panic::{AssertUnwindSafe, catch_unwind};
3use std::sync::LazyLock;
4use std::sync::mpsc::{TryRecvError, sync_channel};
5
6use polars_utils::with_drop::WithDrop;
7use rayon::{ThreadPool, ThreadPoolBuilder, Yield};
8
/// Zero-sized handle to polars' rayon-based compute runtime.
///
/// Its methods pick the pool that work from the calling thread should run on (the global
/// `THREAD_POOL` or a per-thread single-threaded pool), falling back to rayon's global
/// functions on wasm targets that have no pool.
#[derive(Clone, Copy, Debug)]
pub struct RAYON;
10
11// Thread locals to allow disabling threading for specific threads.
12#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
13thread_local! {
14    static NOOP_POOL: RefCell<ThreadPool> = RefCell::new(
15        ThreadPoolBuilder::new()
16            .use_current_thread()
17            .num_threads(1)
18            .build()
19            .expect("could not create no-op thread pool")
20    );
21}
22
23impl RAYON {
24    pub fn install<OP, R>(&self, op: OP) -> R
25    where
26        OP: FnOnce() -> R + Send,
27        R: Send,
28    {
29        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
30        {
31            op()
32        }
33
34        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
35        {
36            self.with(|p| p.install(op))
37        }
38    }
39
40    pub fn join<A, B, RA, RB>(&self, oper_a: A, oper_b: B) -> (RA, RB)
41    where
42        A: FnOnce() -> RA + Send,
43        B: FnOnce() -> RB + Send,
44        RA: Send,
45        RB: Send,
46    {
47        self.install(|| rayon::join(oper_a, oper_b))
48    }
49
50    pub fn scope<'scope, OP, R>(&self, op: OP) -> R
51    where
52        OP: FnOnce(&rayon::Scope<'scope>) -> R + Send,
53        R: Send,
54    {
55        self.install(|| rayon::scope(op))
56    }
57
58    pub fn spawn<OP>(&self, op: OP)
59    where
60        OP: FnOnce() + Send + 'static,
61    {
62        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
63        {
64            rayon::spawn(op)
65        }
66
67        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
68        {
69            self.with(|p| {
70                p.spawn(op);
71                if p.current_num_threads() == 1 {
72                    p.yield_now();
73                }
74            })
75        }
76    }
77
78    pub fn spawn_fifo<OP>(&self, op: OP)
79    where
80        OP: FnOnce() + Send + 'static,
81    {
82        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
83        {
84            rayon::spawn_fifo(op)
85        }
86
87        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
88        {
89            self.with(|p| {
90                p.spawn_fifo(op);
91                if p.current_num_threads() == 1 {
92                    p.yield_now();
93                }
94            })
95        }
96    }
97
98    pub fn current_thread_has_pending_tasks(&self) -> Option<bool> {
99        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
100        {
101            None
102        }
103
104        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
105        {
106            self.with(|p| p.current_thread_has_pending_tasks())
107        }
108    }
109
110    pub fn current_thread_index(&self) -> Option<usize> {
111        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
112        {
113            rayon::current_thread_index()
114        }
115
116        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
117        {
118            self.with(|p| p.current_thread_index())
119        }
120    }
121
122    pub fn current_num_threads(&self) -> usize {
123        #[cfg(not(any(target_os = "emscripten", not(target_family = "wasm"))))]
124        {
125            rayon::current_num_threads()
126        }
127
128        #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
129        {
130            self.with(|p| p.current_num_threads())
131        }
132    }
133
134    #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
135    pub fn with<OP, R>(&self, op: OP) -> R
136    where
137        OP: FnOnce(&ThreadPool) -> R + Send,
138        R: Send,
139    {
140        if polars_async::executor::ALLOW_RAYON_THREADS.get()
141            || THREAD_POOL.current_thread_index().is_some()
142        {
143            op(&THREAD_POOL)
144        } else {
145            NOOP_POOL.with(|v| op(&v.borrow()))
146        }
147    }
148
149    /// Calls a blocking function without blocking the rayon thread pool by
150    /// moving it to a different thread.
151    ///
152    /// If this thread isn't a rayon thread this simply calls f directly.
153    pub fn block_on<R: Send, F: FnOnce() -> R + Send>(&self, f: F) -> R {
154        if THREAD_POOL.current_thread_index().is_some() {
155            let (send, recv) = sync_channel(1);
156            let mut opt_f: Option<F> = Some(f);
157            let mut wrap_f = || {
158                let f = AssertUnwindSafe(opt_f.take().unwrap());
159                send.send(catch_unwind(f)).unwrap();
160            };
161
162            // SAFETY: we always await the future to completion before returning from here, meaning
163            // wrap_f stays alive for as long as it needs to. If for whatever reason we unwind we
164            // abort.
165            let abort = WithDrop::new((), |()| std::process::abort());
166            let ref_wrap_f: &mut (dyn Send + FnMut()) = &mut wrap_f;
167            let static_wrap_f: &'static mut (dyn Send + FnMut() + 'static) =
168                unsafe { core::mem::transmute(ref_wrap_f) };
169            ASYNC.spawn_blocking(static_wrap_f);
170
171            loop {
172                match recv.try_recv() {
173                    Ok(r) => {
174                        WithDrop::dismiss(abort);
175                        match r {
176                            Ok(v) => return v,
177                            Err(panic) => std::panic::resume_unwind(panic),
178                        }
179                    },
180                    Err(TryRecvError::Empty) => match rayon::yield_now() {
181                        Some(Yield::Executed) => {},
182                        Some(Yield::Idle) => std::thread::yield_now(),
183                        None => unreachable!(),
184                    },
185                    Err(TryRecvError::Disconnected) => unreachable!(),
186                }
187            }
188        } else {
189            f()
190        }
191    }
192}
193
194// this is re-exported in utils for polars child crates
195#[cfg(not(target_family = "wasm"))] // only use this on non wasm targets
196pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
197    let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string());
198    ThreadPoolBuilder::new()
199        .num_threads(polars_config::config().max_threads())
200        .thread_name(move |i| format!("{thread_name}-{i}"))
201        .build()
202        .expect("could not spawn threads")
203});
204
// On emscripten, use a single rayon thread. The pool is built with `use_current_thread`, so
// its only worker is whichever thread first forces this lazy static.
#[cfg(all(target_os = "emscripten", target_family = "wasm"))]
pub static THREAD_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
    let builder = ThreadPoolBuilder::new().use_current_thread().num_threads(1);
    builder.build().expect("could not create pool")
});
213
214pub use polars_async::ASYNC;