1use std::error::Error;
2use std::future::Future;
3use std::ops::Deref;
4use std::sync::LazyLock;
5use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
6
7use polars_core::POOL;
8use polars_core::config::{self, verbose};
9use tokio::runtime::{Builder, Runtime};
10use tokio::sync::Semaphore;
11
12static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
13pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;
14
15static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
18 let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
19 .as_deref()
20 .map(|x| x.parse().expect("integer"))
21 .unwrap_or(64 * 1024 * 1024);
22
23 if config::verbose() {
24 eprintln!("async download_chunk_size: {}", v)
25 }
26
27 v
28});
29
30pub(super) fn get_download_chunk_size() -> usize {
31 *DOWNLOAD_CHUNK_SIZE
32}
33
34static UPLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
35 let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE")
36 .as_deref()
37 .map(|x| x.parse().expect("integer"))
38 .unwrap_or(64 * 1024 * 1024);
39
40 if config::verbose() {
41 eprintln!("async upload_chunk_size: {}", v)
42 }
43
44 v
45});
46
47pub(super) fn get_upload_chunk_size() -> usize {
48 *UPLOAD_CHUNK_SIZE
49}
50
/// Reports how many bytes a value represents.
///
/// `tune_with_concurrency_budget` feeds this figure into the concurrency tuner
/// as the amount downloaded by a request.
pub trait GetSize {
    /// Number of bytes represented by `self`.
    fn size(&self) -> u64;
}
54
55impl GetSize for bytes::Bytes {
56 fn size(&self) -> u64 {
57 self.len() as u64
58 }
59}
60
61impl<T: GetSize> GetSize for Vec<T> {
62 fn size(&self) -> u64 {
63 self.iter().map(|v| v.size()).sum()
64 }
65}
66
67impl<T: GetSize, E: Error> GetSize for Result<T, E> {
68 fn size(&self) -> u64 {
69 match self {
70 Ok(v) => v.size(),
71 Err(_) => 0,
72 }
73 }
74}
75
/// A bare byte count that reports itself through [`GetSize`].
#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        let Size(bytes) = self;
        *bytes
    }
}

#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Size(value)
    }
}
91
/// States of the concurrency tuner's step/accept loop.
enum Optimization {
    /// The next tuning step adds a permit unconditionally.
    Step,
    /// A permit was just added. The next step keeps it (and adds another) if
    /// throughput went up; otherwise tuning finishes.
    Accept,
    /// Tuning is over and no further steps are taken.
    Finished,
}
97
98struct SemaphoreTuner {
99 previous_download_speed: u64,
100 last_tune: std::time::Instant,
101 downloaded: AtomicU64,
102 download_time: AtomicU64,
103 opt_state: Optimization,
104 increments: u32,
105}
106
107impl SemaphoreTuner {
108 fn new() -> Self {
109 Self {
110 previous_download_speed: 0,
111 last_tune: std::time::Instant::now(),
112 downloaded: AtomicU64::new(0),
113 download_time: AtomicU64::new(0),
114 opt_state: Optimization::Step,
115 increments: 0,
116 }
117 }
118 fn should_tune(&self) -> bool {
119 match self.opt_state {
120 Optimization::Finished => false,
121 _ => self.last_tune.elapsed().as_millis() > 350,
122 }
123 }
124
125 fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
126 self.downloaded
127 .fetch_add(downloaded_bytes, Ordering::Relaxed);
128 self.download_time
129 .fetch_add(download_time, Ordering::Relaxed);
130 }
131
132 fn increment(&mut self, semaphore: &Semaphore) {
133 semaphore.add_permits(1);
134 self.increments += 1;
135 }
136
137 fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
138 let bytes_downloaded = self.downloaded.fetch_add(0, Ordering::Relaxed);
139 let time_elapsed = self.download_time.fetch_add(0, Ordering::Relaxed);
140 let download_speed = bytes_downloaded
141 .checked_div(time_elapsed)
142 .unwrap_or_default();
143
144 let increased = download_speed > self.previous_download_speed;
145 self.previous_download_speed = download_speed;
146 match self.opt_state {
147 Optimization::Step => {
148 self.increment(semaphore);
149 self.opt_state = Optimization::Accept
150 },
151 Optimization::Accept => {
152 if increased {
154 self.increment(semaphore);
156 }
158 else {
160 self.opt_state = Optimization::Finished;
161 FINISHED_TUNING.store(true, Ordering::Relaxed);
162 if verbose() {
163 eprintln!(
164 "concurrency tuner finished after adding {} steps",
165 self.increments
166 )
167 }
168 return true;
170 }
171 },
172 Optimization::Finished => {},
173 }
174 self.last_tune = std::time::Instant::now();
175 false
177 }
178}
179static INCR: AtomicU8 = AtomicU8::new(0);
180static FINISHED_TUNING: AtomicBool = AtomicBool::new(false);
181static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
182 std::sync::OnceLock::new();
183
184fn get_semaphore() -> &'static (Semaphore, u32) {
185 CONCURRENCY_BUDGET.get_or_init(|| {
186 let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
187 .map(|s| {
188 let budget = s.parse::<usize>().expect("integer");
189 FINISHED_TUNING.store(true, Ordering::Relaxed);
190 budget
191 })
192 .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
193 (Semaphore::new(permits), permits as u32)
194 })
195}
196
197pub(crate) fn get_concurrency_limit() -> u32 {
198 get_semaphore().1
199}
200
201pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
202where
203 F: FnOnce() -> Fut,
204 Fut: Future,
205 Fut::Output: GetSize,
206{
207 let (semaphore, initial_budget) = get_semaphore();
208
209 assert!(requested_budget <= *initial_budget);
211
212 let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
215
216 let now = std::time::Instant::now();
217 let res = callable().await;
218
219 if FINISHED_TUNING.load(Ordering::Relaxed) || res.size() == 0 {
220 return res;
221 }
222
223 let duration = now.elapsed().as_millis() as u64;
224 let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));
225
226 let Ok(tuner) = permit_store.try_read() else {
227 return res;
228 };
229 tuner.add_stats(res.size(), duration);
231
232 if !tuner.should_tune() {
234 return res;
235 }
236 drop(tuner);
238
239 if (INCR.fetch_add(1, Ordering::Relaxed) % 5) != 0 {
241 return res;
242 }
243 let Ok(mut tuner) = permit_store.try_write() else {
245 return res;
246 };
247 let finished = tuner.tune(semaphore);
248 if finished {
249 drop(_permit_acq);
250 let undo = semaphore.acquire().await.unwrap();
252 std::mem::forget(undo)
253 }
254 res
255}
256
257pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
258where
259 F: FnOnce() -> Fut,
260 Fut: Future,
261{
262 let (semaphore, initial_budget) = get_semaphore();
263
264 assert!(requested_budget <= *initial_budget);
266
267 let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
270
271 callable().await
272}
273
274pub struct RuntimeManager {
275 rt: Runtime,
276}
277
278impl RuntimeManager {
279 fn new() -> Self {
280 let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
281 .map(|x| x.parse::<usize>().expect("integer"))
282 .unwrap_or(POOL.current_num_threads().clamp(1, 4));
283
284 if polars_core::config::verbose() {
285 eprintln!("async thread count: {}", n_threads);
286 }
287
288 let rt = Builder::new_multi_thread()
289 .worker_threads(n_threads)
290 .enable_io()
291 .enable_time()
292 .build()
293 .unwrap();
294
295 Self { rt }
296 }
297
298 pub fn block_in_place_on<F>(&self, future: F) -> F::Output
305 where
306 F: Future,
307 {
308 tokio::task::block_in_place(|| self.rt.block_on(future))
309 }
310
311 pub fn block_on<F>(&self, future: F) -> F::Output
314 where
315 F: Future,
316 {
317 self.rt.block_on(future)
318 }
319
320 pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
322 where
323 F: Future + Send + 'static,
324 F::Output: Send + 'static,
325 {
326 self.rt.spawn(future)
327 }
328
329 pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
331 where
332 F: FnOnce() -> R + Send + 'static,
333 R: Send + 'static,
334 {
335 self.rt.spawn_blocking(f)
336 }
337
338 pub async fn spawn_rayon<F, O>(&self, func: F) -> O
342 where
343 F: FnOnce() -> O + Send + Sync + 'static,
344 O: Send + Sync + 'static,
345 {
346 if POOL.current_thread_index().is_some() {
347 tokio::task::block_in_place(func)
351 } else {
352 let (tx, rx) = tokio::sync::oneshot::channel();
353
354 let func = move || {
355 let out = func();
356 let _ = tx.send(out);
358 };
359
360 POOL.spawn(func);
361
362 rx.await.unwrap()
363 }
364 }
365}
366
367static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);
368
369pub fn get_runtime() -> &'static RuntimeManager {
370 RUNTIME.deref()
371}