polars_io/pl_async.rs

use std::error::Error;
use std::future::Future;
use std::ops::Deref;
use std::sync::LazyLock;

use polars_core::POOL;
use polars_core::config::{self, verbose};
use polars_utils::relaxed_cell::RelaxedCell;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::Semaphore;

/// Global semaphore bounding the number of concurrent async requests, stored
/// together with its initial permit count.
static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;

/// Used to determine chunks when splitting large ranges, or combining small
/// ranges.
static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async download_chunk_size: {v}")
    }

    v
});

pub(super) fn get_download_chunk_size() -> usize {
    *DOWNLOAD_CHUNK_SIZE
}

static UPLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async upload_chunk_size: {v}")
    }

    v
});

pub(super) fn get_upload_chunk_size() -> usize {
    *UPLOAD_CHUNK_SIZE
}
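
// A minimal sketch of how the chunk size is meant to be used: splitting a
// large byte range into `DOWNLOAD_CHUNK_SIZE`-sized pieces. The helper name is
// hypothetical and not a public API.
#[allow(dead_code)]
fn split_range_sketch(range: std::ops::Range<usize>) -> Vec<std::ops::Range<usize>> {
    let chunk_size = get_download_chunk_size();
    let mut chunks = Vec::new();
    let mut start = range.start;
    while start < range.end {
        // Each chunk covers at most `chunk_size` bytes; the last one may be shorter.
        let end = std::cmp::min(start + chunk_size, range.end);
        chunks.push(start..end);
        start = end;
    }
    chunks
}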

/// Byte size of a transfer result; used by the tuner to estimate throughput.
pub trait GetSize {
    fn size(&self) -> u64;
}

impl GetSize for bytes::Bytes {
    fn size(&self) -> u64 {
        self.len() as u64
    }
}

impl<T: GetSize> GetSize for Vec<T> {
    fn size(&self) -> u64 {
        self.iter().map(|v| v.size()).sum()
    }
}

impl<T: GetSize, E: Error> GetSize for Result<T, E> {
    fn size(&self) -> u64 {
        match self {
            Ok(v) => v.size(),
            Err(_) => 0,
        }
    }
}

#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        self.0
    }
}
#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
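
// A minimal sketch of reporting a transfer size to the tuner: when a future's
// output does not implement `GetSize` directly, the byte count can be wrapped
// in `Size`. `report_size_sketch` and `n_bytes` are hypothetical.
#[cfg(feature = "cloud")]
#[allow(dead_code)]
async fn report_size_sketch(n_bytes: u64) {
    let _size: Size = tune_with_concurrency_budget(1, || async move {
        // Wrap the byte count so the tuner can observe throughput for this
        // request.
        Size::from(n_bytes)
    })
    .await;
}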

enum Optimization {
    Step,
    Accept,
    Finished,
}

/// Hill-climbing tuner: keeps adding permits to the semaphore as long as the
/// measured download speed improves, and stops at the first step that does
/// not.
struct SemaphoreTuner {
    previous_download_speed: u64,
    last_tune: std::time::Instant,
    downloaded: RelaxedCell<u64>,
    download_time: RelaxedCell<u64>,
    opt_state: Optimization,
    increments: u32,
}

impl SemaphoreTuner {
    fn new() -> Self {
        Self {
            previous_download_speed: 0,
            last_tune: std::time::Instant::now(),
            downloaded: RelaxedCell::from(0),
            download_time: RelaxedCell::from(0),
            opt_state: Optimization::Step,
            increments: 0,
        }
    }
    fn should_tune(&self) -> bool {
        match self.opt_state {
            Optimization::Finished => false,
            _ => self.last_tune.elapsed().as_millis() > 350,
        }
    }

    fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
        self.downloaded.fetch_add(downloaded_bytes);
        self.download_time.fetch_add(download_time);
    }

    fn increment(&mut self, semaphore: &Semaphore) {
        semaphore.add_permits(1);
        self.increments += 1;
    }

    fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
        let bytes_downloaded = self.downloaded.load();
        let time_elapsed = self.download_time.load();
        let download_speed = bytes_downloaded
            .checked_div(time_elapsed)
            .unwrap_or_default();

        let increased = download_speed > self.previous_download_speed;
        self.previous_download_speed = download_speed;
        match self.opt_state {
            Optimization::Step => {
                self.increment(semaphore);
                self.opt_state = Optimization::Accept
            },
            Optimization::Accept => {
                // Accept the step
                if increased {
                    // Set new step
                    self.increment(semaphore);
                    // Keep accept state to check next iteration
                }
                // Decline the step
                else {
                    self.opt_state = Optimization::Finished;
                    FINISHED_TUNING.store(true);
                    if verbose() {
                        eprintln!(
                            "concurrency tuner finished after adding {} steps",
                            self.increments
                        )
                    }
                    // Finished.
                    return true;
                }
            },
            Optimization::Finished => {},
        }
        self.last_tune = std::time::Instant::now();
        // Not finished.
        false
    }
}
// Counter used so that only a fraction of tasks attempt to lock the tuner.
static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
// Set once the tuner has settled on a permit count (or tuning is disabled).
static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
    std::sync::OnceLock::new();

fn get_semaphore() -> &'static (Semaphore, u32) {
    CONCURRENCY_BUDGET.get_or_init(|| {
        let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
            .map(|s| {
                let budget = s.parse::<usize>().expect("integer");
                // An explicit budget disables tuning.
                FINISHED_TUNING.store(true);
                budget
            })
            .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
        (Semaphore::new(permits), permits as u32)
    })
}

pub(crate) fn get_concurrency_limit() -> u32 {
    get_semaphore().1
}

/// Acquires `requested_budget` permits from the global concurrency budget,
/// runs `callable`, and feeds the observed transfer size and duration to the
/// concurrency tuner.
pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
    Fut::Output: GetSize,
{
    let (semaphore, initial_budget) = get_semaphore();

    // This would never finish otherwise.
    assert!(requested_budget <= *initial_budget);

    // Keep the permit around.
    // On drop it is returned to the semaphore.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    let now = std::time::Instant::now();
    let res = callable().await;

    if FINISHED_TUNING.load() || res.size() == 0 {
        return res;
    }

    let duration = now.elapsed().as_millis() as u64;
    let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));

    let Ok(tuner) = permit_store.try_read() else {
        return res;
    };
    // Keep track of download speed.
    tuner.add_stats(res.size(), duration);

    // We only tune every n ms.
    if !tuner.should_tune() {
        return res;
    }
    // Drop the read guard before trying to acquire a writer.
    drop(tuner);

    // Reduce locking by letting only 1 in 5 tasks lock the tuner.
    if !INCR.fetch_add(1).is_multiple_of(5) {
        return res;
    }
    // Never block on the lock, as that could deadlock; this code can run under rayon.
    let Ok(mut tuner) = permit_store.try_write() else {
        return res;
    };
    let finished = tuner.tune(semaphore);
    if finished {
        drop(_permit_acq);
        // Undo the last (non-improving) step by permanently removing one permit.
        let undo = semaphore.acquire().await.unwrap();
        std::mem::forget(undo)
    }
    res
}

/// Acquires `requested_budget` permits from the global concurrency budget and
/// runs `callable`, without participating in tuning.
pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
{
    let (semaphore, initial_budget) = get_semaphore();

    // This would never finish otherwise.
    assert!(requested_budget <= *initial_budget);

    // Keep the permit around.
    // On drop it is returned to the semaphore.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    callable().await
}
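
// A minimal sketch of wrapping an arbitrary request future in the shared
// budget so the number of in-flight requests stays bounded. `make_request`
// stands in for any real request constructor; the helper is hypothetical.
#[allow(dead_code)]
async fn budgeted_request_sketch<F, Fut, T>(make_request: F) -> T
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = T>,
{
    // A single permit is acquired before the request starts and released (via
    // the permit's Drop impl) when this call returns.
    with_concurrency_budget(1, make_request).await
}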

/// Wrapper around a dedicated multi-threaded tokio runtime used for async I/O.
pub struct RuntimeManager {
    rt: Runtime,
}

impl RuntimeManager {
    fn new() -> Self {
        let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
            .map(|x| x.parse::<usize>().expect("integer"))
            .unwrap_or(POOL.current_num_threads().clamp(1, 4));

        if polars_core::config::verbose() {
            eprintln!("async thread count: {n_threads}");
        }

        let rt = Builder::new_multi_thread()
            .worker_threads(n_threads)
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        Self { rt }
    }

    /// Forcibly blocks this thread to evaluate the given future. This is
    /// dangerous if called re-entrantly on an async worker thread, as the
    /// entire thread pool can end up blocked, resulting in a deadlock. To
    /// guard against this, use `block_on`, which panics if called from an
    /// async thread.
    pub fn block_in_place_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        tokio::task::block_in_place(|| self.rt.block_on(future))
    }

    /// Blocks this thread to evaluate the given future. Panics if the current
    /// thread is an async runtime worker thread.
    pub fn block_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        self.rt.block_on(future)
    }

    /// Spawns a future onto the Tokio runtime (see [`tokio::runtime::Runtime::spawn`]).
    pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.rt.spawn(future)
    }

    /// See [`tokio::runtime::Runtime::spawn_blocking`].
    pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.rt.spawn_blocking(f)
    }

    /// Run a task on the rayon threadpool. To avoid deadlocks, if the current
    /// thread is already a rayon thread, the task is executed on the current
    /// thread inside tokio's `block_in_place`, which lets the runtime hand
    /// polling of other futures to another thread.
    pub async fn spawn_rayon<F, O>(&self, func: F) -> O
    where
        F: FnOnce() -> O + Send + Sync + 'static,
        O: Send + Sync + 'static,
    {
        if POOL.current_thread_index().is_some() {
            // We are a rayon thread, so we can't use POOL.spawn, as that would mean we spawn a
            // task and block until another rayon thread executes it - we would deadlock if all
            // rayon threads did this.
            // Safety: The tokio runtime flavor is multi-threaded.
            tokio::task::block_in_place(func)
        } else {
            let (tx, rx) = tokio::sync::oneshot::channel();

            let func = move || {
                let out = func();
                // Don't unwrap the send attempt - the async task could have been cancelled.
                let _ = tx.send(out);
            };

            POOL.spawn(func);

            rx.await.unwrap()
        }
    }
}
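
// A minimal sketch of offloading a CPU-bound closure to rayon from async code
// via `spawn_rayon`, which avoids deadlocks when the caller itself runs on a
// rayon thread. The helper and its input are hypothetical; `get_runtime` is
// defined at the bottom of this module.
#[allow(dead_code)]
async fn offload_to_rayon_sketch(data: Vec<u64>) -> u64 {
    get_runtime()
        .spawn_rayon(move || {
            // CPU-bound work runs on the rayon pool, or inline (wrapped in
            // `block_in_place`) if we are already on a rayon thread.
            data.iter().sum()
        })
        .await
}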

static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);

pub fn get_runtime() -> &'static RuntimeManager {
    RUNTIME.deref()
}
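
// A minimal sketch of entering the shared runtime from synchronous code while
// respecting the concurrency budget. `fetch_metadata` is a hypothetical async
// operation returning `bytes::Bytes`.
#[allow(dead_code)]
fn blocking_entry_sketch<F, Fut>(fetch_metadata: F) -> bytes::Bytes
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = bytes::Bytes>,
{
    // Block the current (synchronous) thread on the shared runtime; see the
    // deadlock caveats on `block_in_place_on` above.
    get_runtime().block_in_place_on(with_concurrency_budget(1, fetch_metadata))
}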