polars_io/
pl_async.rs

1use std::error::Error;
2use std::future::Future;
3use std::ops::Deref;
4use std::sync::LazyLock;
5use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
6
7use polars_core::POOL;
8use polars_core::config::{self, verbose};
9use tokio::runtime::{Builder, Runtime};
10use tokio::sync::Semaphore;
11
12static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
13pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;
14
15/// Used to determine chunks when splitting large ranges, or combining small
16/// ranges.
17static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
18    let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
19        .as_deref()
20        .map(|x| x.parse().expect("integer"))
21        .unwrap_or(64 * 1024 * 1024);
22
23    if config::verbose() {
24        eprintln!("async download_chunk_size: {}", v)
25    }
26
27    v
28});
29
30pub(super) fn get_download_chunk_size() -> usize {
31    *DOWNLOAD_CHUNK_SIZE
32}
33
34static UPLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
35    let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE")
36        .as_deref()
37        .map(|x| x.parse().expect("integer"))
38        .unwrap_or(64 * 1024 * 1024);
39
40    if config::verbose() {
41        eprintln!("async upload_chunk_size: {}", v)
42    }
43
44    v
45});
46
47pub(super) fn get_upload_chunk_size() -> usize {
48    *UPLOAD_CHUNK_SIZE
49}
50
/// Implemented by values that can report a size as a `u64`.
///
/// The implementations in this file report byte counts (e.g. the length of a byte buffer).
pub trait GetSize {
    /// Returns the size of `self`.
    fn size(&self) -> u64;
}
54
55impl GetSize for bytes::Bytes {
56    fn size(&self) -> u64 {
57        self.len() as u64
58    }
59}
60
61impl<T: GetSize> GetSize for Vec<T> {
62    fn size(&self) -> u64 {
63        self.iter().map(|v| v.size()).sum()
64    }
65}
66
67impl<T: GetSize, E: Error> GetSize for Result<T, E> {
68    fn size(&self) -> u64 {
69        match self {
70            Ok(v) => v.size(),
71            Err(_) => 0,
72        }
73    }
74}
75
/// Newtype around a `u64` so that a plain number can be reported through [`GetSize`].
#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);
78
/// `Size` reports the wrapped value unchanged.
#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        let Self(value) = self;
        *value
    }
}
/// Wraps a raw `u64` in a [`Size`].
#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Size(value)
    }
}
91
/// State of the concurrency tuner driven by `SemaphoreTuner::tune`.
enum Optimization {
    /// The next tuning step adds a permit unconditionally and moves to `Accept`.
    Step,
    /// The next tuning step keeps adding permits while the measured speed increases, and
    /// switches to `Finished` as soon as it does not.
    Accept,
    /// Tuning is over; no further steps are taken.
    Finished,
}
97
98struct SemaphoreTuner {
99    previous_download_speed: u64,
100    last_tune: std::time::Instant,
101    downloaded: AtomicU64,
102    download_time: AtomicU64,
103    opt_state: Optimization,
104    increments: u32,
105}
106
107impl SemaphoreTuner {
108    fn new() -> Self {
109        Self {
110            previous_download_speed: 0,
111            last_tune: std::time::Instant::now(),
112            downloaded: AtomicU64::new(0),
113            download_time: AtomicU64::new(0),
114            opt_state: Optimization::Step,
115            increments: 0,
116        }
117    }
118    fn should_tune(&self) -> bool {
119        match self.opt_state {
120            Optimization::Finished => false,
121            _ => self.last_tune.elapsed().as_millis() > 350,
122        }
123    }
124
125    fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
126        self.downloaded
127            .fetch_add(downloaded_bytes, Ordering::Relaxed);
128        self.download_time
129            .fetch_add(download_time, Ordering::Relaxed);
130    }
131
132    fn increment(&mut self, semaphore: &Semaphore) {
133        semaphore.add_permits(1);
134        self.increments += 1;
135    }
136
137    fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
138        let bytes_downloaded = self.downloaded.fetch_add(0, Ordering::Relaxed);
139        let time_elapsed = self.download_time.fetch_add(0, Ordering::Relaxed);
140        let download_speed = bytes_downloaded
141            .checked_div(time_elapsed)
142            .unwrap_or_default();
143
144        let increased = download_speed > self.previous_download_speed;
145        self.previous_download_speed = download_speed;
146        match self.opt_state {
147            Optimization::Step => {
148                self.increment(semaphore);
149                self.opt_state = Optimization::Accept
150            },
151            Optimization::Accept => {
152                // Accept the step
153                if increased {
154                    // Set new step
155                    self.increment(semaphore);
156                    // Keep accept state to check next iteration
157                }
158                // Decline the step
159                else {
160                    self.opt_state = Optimization::Finished;
161                    FINISHED_TUNING.store(true, Ordering::Relaxed);
162                    if verbose() {
163                        eprintln!(
164                            "concurrency tuner finished after adding {} steps",
165                            self.increments
166                        )
167                    }
168                    // Finished.
169                    return true;
170                }
171            },
172            Optimization::Finished => {},
173        }
174        self.last_tune = std::time::Instant::now();
175        // Not finished.
176        false
177    }
178}
179static INCR: AtomicU8 = AtomicU8::new(0);
180static FINISHED_TUNING: AtomicBool = AtomicBool::new(false);
181static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
182    std::sync::OnceLock::new();
183
184fn get_semaphore() -> &'static (Semaphore, u32) {
185    CONCURRENCY_BUDGET.get_or_init(|| {
186        let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
187            .map(|s| {
188                let budget = s.parse::<usize>().expect("integer");
189                FINISHED_TUNING.store(true, Ordering::Relaxed);
190                budget
191            })
192            .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
193        (Semaphore::new(permits), permits as u32)
194    })
195}
196
197pub(crate) fn get_concurrency_limit() -> u32 {
198    get_semaphore().1
199}
200
201pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
202where
203    F: FnOnce() -> Fut,
204    Fut: Future,
205    Fut::Output: GetSize,
206{
207    let (semaphore, initial_budget) = get_semaphore();
208
209    // This would never finish otherwise.
210    assert!(requested_budget <= *initial_budget);
211
212    // Keep permit around.
213    // On drop it is returned to the semaphore.
214    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
215
216    let now = std::time::Instant::now();
217    let res = callable().await;
218
219    if FINISHED_TUNING.load(Ordering::Relaxed) || res.size() == 0 {
220        return res;
221    }
222
223    let duration = now.elapsed().as_millis() as u64;
224    let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));
225
226    let Ok(tuner) = permit_store.try_read() else {
227        return res;
228    };
229    // Keep track of download speed
230    tuner.add_stats(res.size(), duration);
231
232    // We only tune every n ms
233    if !tuner.should_tune() {
234        return res;
235    }
236    // Drop the read tuner before trying to acquire a writer
237    drop(tuner);
238
239    // Reduce locking by letting only 1 in 5 tasks lock the tuner
240    if (INCR.fetch_add(1, Ordering::Relaxed) % 5) != 0 {
241        return res;
242    }
243    // Never lock as we will deadlock. This can run under rayon
244    let Ok(mut tuner) = permit_store.try_write() else {
245        return res;
246    };
247    let finished = tuner.tune(semaphore);
248    if finished {
249        drop(_permit_acq);
250        // Undo the last step
251        let undo = semaphore.acquire().await.unwrap();
252        std::mem::forget(undo)
253    }
254    res
255}
256
257pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
258where
259    F: FnOnce() -> Fut,
260    Fut: Future,
261{
262    let (semaphore, initial_budget) = get_semaphore();
263
264    // This would never finish otherwise.
265    assert!(requested_budget <= *initial_budget);
266
267    // Keep permit around.
268    // On drop it is returned to the semaphore.
269    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
270
271    callable().await
272}
273
274pub struct RuntimeManager {
275    rt: Runtime,
276}
277
278impl RuntimeManager {
279    fn new() -> Self {
280        let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
281            .map(|x| x.parse::<usize>().expect("integer"))
282            .unwrap_or(POOL.current_num_threads().clamp(1, 4));
283
284        if polars_core::config::verbose() {
285            eprintln!("async thread count: {}", n_threads);
286        }
287
288        let rt = Builder::new_multi_thread()
289            .worker_threads(n_threads)
290            .enable_io()
291            .enable_time()
292            .build()
293            .unwrap();
294
295        Self { rt }
296    }
297
298    /// Shorthand for `tokio::task::block_in_place(|| block_on(f))`. This is a variant of `block_on`
299    /// that is safe to call from if the current thread has already entered the async runtime, or
300    /// is a rayon thread.
301    ///
302    /// # Safety
303    /// The tokio runtime flavor is multi-threaded.
304    pub fn block_in_place_on<F>(&self, future: F) -> F::Output
305    where
306        F: Future,
307    {
308        tokio::task::block_in_place(|| self.rt.block_on(future))
309    }
310
311    /// Note: `block_in_place_on` should be used instead if the current thread is a rayon thread or
312    /// has already entered the async runtime.
313    pub fn block_on<F>(&self, future: F) -> F::Output
314    where
315        F: Future,
316    {
317        self.rt.block_on(future)
318    }
319
320    /// Spawns a future onto the Tokio runtime (see [`tokio::runtime::Runtime::spawn`]).
321    pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
322    where
323        F: Future + Send + 'static,
324        F::Output: Send + 'static,
325    {
326        self.rt.spawn(future)
327    }
328
329    // See [`tokio::runtime::Runtime::spawn_blocking`].
330    pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
331    where
332        F: FnOnce() -> R + Send + 'static,
333        R: Send + 'static,
334    {
335        self.rt.spawn_blocking(f)
336    }
337
338    /// Run a task on the rayon threadpool. To avoid deadlocks, if the current thread is already a
339    /// rayon thread, the task is executed on the current thread after tokio's `block_in_place` is
340    /// used to spawn another thread to poll futures.
341    pub async fn spawn_rayon<F, O>(&self, func: F) -> O
342    where
343        F: FnOnce() -> O + Send + Sync + 'static,
344        O: Send + Sync + 'static,
345    {
346        if POOL.current_thread_index().is_some() {
347            // We are a rayon thread, so we can't use POOL.spawn as it would mean we spawn a task and block until
348            // another rayon thread executes it - we would deadlock if all rayon threads did this.
349            // Safety: The tokio runtime flavor is multi-threaded.
350            tokio::task::block_in_place(func)
351        } else {
352            let (tx, rx) = tokio::sync::oneshot::channel();
353
354            let func = move || {
355                let out = func();
356                // Don't unwrap send attempt - async task could be cancelled.
357                let _ = tx.send(out);
358            };
359
360            POOL.spawn(func);
361
362            rx.await.unwrap()
363        }
364    }
365}
366
367static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);
368
369pub fn get_runtime() -> &'static RuntimeManager {
370    RUNTIME.deref()
371}