use std::error::Error;
use std::future::Future;
use std::ops::Deref;
use std::sync::LazyLock;

use polars_core::POOL;
use polars_core::config::{self, verbose};
use polars_utils::relaxed_cell::RelaxedCell;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::Semaphore;

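// The global concurrency budget: a semaphore gating concurrent cloud requests,
// paired with the number of permits it was created with.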
static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;

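/// Download chunk size in bytes; set via `POLARS_DOWNLOAD_CHUNK_SIZE`,
/// defaulting to 64 MiB.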
static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async download_chunk_size: {v}")
    }

    v
});

pub(super) fn get_download_chunk_size() -> usize {
    *DOWNLOAD_CHUNK_SIZE
}

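/// Upload chunk size in bytes; set via `POLARS_UPLOAD_CHUNK_SIZE`,
/// defaulting to 64 MiB.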
static UPLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async upload_chunk_size: {v}")
    }

    v
});

pub(super) fn get_upload_chunk_size() -> usize {
    *UPLOAD_CHUNK_SIZE
}

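/// Byte size of a transfer result, used to feed download statistics to the
/// concurrency tuner. As an illustration, a hypothetical wrapper around a raw
/// buffer could implement it like this:
///
/// ```ignore
/// struct Payload(Vec<u8>);
///
/// impl GetSize for Payload {
///     fn size(&self) -> u64 {
///         self.0.len() as u64
///     }
/// }
/// ```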
pub trait GetSize {
    fn size(&self) -> u64;
}

impl GetSize for bytes::Bytes {
    fn size(&self) -> u64 {
        self.len() as u64
    }
}

impl<T: GetSize> GetSize for Vec<T> {
    fn size(&self) -> u64 {
        self.iter().map(|v| v.size()).sum()
    }
}

impl<T: GetSize, E: Error> GetSize for Result<T, E> {
    fn size(&self) -> u64 {
        match self {
            Ok(v) => v.size(),
            Err(_) => 0,
        }
    }
}

#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        self.0
    }
}

#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Self(value)
    }
}

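/// States of the tuner's hill climb: take a step (add a permit), check whether
/// to accept it, or stop tuning for good.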
enum Optimization {
    Step,
    Accept,
    Finished,
}

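/// Measures download throughput and grows the semaphore's permit count as long
/// as each added permit keeps increasing the measured speed.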
struct SemaphoreTuner {
    previous_download_speed: u64,
    last_tune: std::time::Instant,
    downloaded: RelaxedCell<u64>,
    download_time: RelaxedCell<u64>,
    opt_state: Optimization,
    increments: u32,
}

impl SemaphoreTuner {
    fn new() -> Self {
        Self {
            previous_download_speed: 0,
            last_tune: std::time::Instant::now(),
            downloaded: RelaxedCell::from(0),
            download_time: RelaxedCell::from(0),
            opt_state: Optimization::Step,
            increments: 0,
        }
    }

    fn should_tune(&self) -> bool {
        match self.opt_state {
            Optimization::Finished => false,
            _ => self.last_tune.elapsed().as_millis() > 350,
        }
    }

    fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
        self.downloaded.fetch_add(downloaded_bytes);
        self.download_time.fetch_add(download_time);
    }

    fn increment(&mut self, semaphore: &Semaphore) {
        semaphore.add_permits(1);
        self.increments += 1;
    }

    fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
        let bytes_downloaded = self.downloaded.load();
        let time_elapsed = self.download_time.load();
        let download_speed = bytes_downloaded
            .checked_div(time_elapsed)
            .unwrap_or_default();

        let increased = download_speed > self.previous_download_speed;
        self.previous_download_speed = download_speed;
        match self.opt_state {
            Optimization::Step => {
                self.increment(semaphore);
                self.opt_state = Optimization::Accept
            },
            Optimization::Accept => {
                if increased {
                    // The last step improved throughput; accept it and step again.
                    self.increment(semaphore);
                } else {
                    // Throughput regressed; stop tuning here.
                    self.opt_state = Optimization::Finished;
                    FINISHED_TUNING.store(true);
                    if verbose() {
                        eprintln!(
                            "concurrency tuner finished after adding {} steps",
                            self.increments
                        )
                    }
                    return true;
                }
            },
            Optimization::Finished => {},
        }
        self.last_tune = std::time::Instant::now();
        false
    }
}
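
// Tuning bookkeeping: INCR counts requests so tuning runs only every few
// calls, FINISHED_TUNING short-circuits once tuning is done (or an explicit
// budget was set), and PERMIT_STORE holds the tuner behind an async RwLock.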
static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
    std::sync::OnceLock::new();

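/// Returns the global semaphore and its initial permit count. The budget is
/// taken from `POLARS_CONCURRENCY_BUDGET` if set (which also disables tuning);
/// otherwise it is the rayon thread count, floored at `MAX_BUDGET_PER_REQUEST`.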
fn get_semaphore() -> &'static (Semaphore, u32) {
    CONCURRENCY_BUDGET.get_or_init(|| {
        let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
            .map(|s| {
                let budget = s.parse::<usize>().expect("integer");
                FINISHED_TUNING.store(true);
                budget
            })
            .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
        (Semaphore::new(permits), permits as u32)
    })
}

pub(crate) fn get_concurrency_limit() -> u32 {
    get_semaphore().1
}

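/// Runs `callable` while holding `requested_budget` permits of the global
/// concurrency budget, recording size/time statistics so the tuner can adjust
/// the budget. A sketch of a caller, assuming a hypothetical `fetch_range`
/// helper that resolves to `bytes::Bytes`:
///
/// ```ignore
/// let bytes = tune_with_concurrency_budget(1, || fetch_range(url, 0..1024)).await;
/// ```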
pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
    Fut::Output: GetSize,
{
    let (semaphore, initial_budget) = get_semaphore();

    // Requesting more permits than exist would never finish.
    assert!(requested_budget <= *initial_budget);

    // Keep the permit guard alive for the duration of the request;
    // the permits are returned to the semaphore on drop.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    let now = std::time::Instant::now();
    let res = callable().await;

    if FINISHED_TUNING.load() || res.size() == 0 {
        return res;
    }

    let duration = now.elapsed().as_millis() as u64;
    let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));

    let Ok(tuner) = permit_store.try_read() else {
        return res;
    };
    tuner.add_stats(res.size(), duration);

    // Only tune periodically.
    if !tuner.should_tune() {
        return res;
    }
    drop(tuner);

    // Only tune every few requests.
    if !INCR.fetch_add(1).is_multiple_of(5) {
        return res;
    }
    let Ok(mut tuner) = permit_store.try_write() else {
        return res;
    };
    let finished = tuner.tune(semaphore);
    if finished {
        // Undo the last (rejected) step: release our permits, then remove one
        // permit from the semaphore permanently by forgetting it.
        drop(_permit_acq);
        let undo = semaphore.acquire().await.unwrap();
        std::mem::forget(undo)
    }
    res
}

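/// Like `tune_with_concurrency_budget`, but only enforces the budget; no
/// statistics are recorded and no tuning takes place.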
pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
{
    let (semaphore, initial_budget) = get_semaphore();

    // Requesting more permits than exist would never finish.
    assert!(requested_budget <= *initial_budget);

    // Keep the permit guard alive for the duration of the request;
    // the permits are returned to the semaphore on drop.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    callable().await
}

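/// Owns the tokio runtime used for async I/O and provides helpers for bridging
/// between that runtime and the rayon thread pool.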
pub struct RuntimeManager {
    rt: Runtime,
}

impl RuntimeManager {
    fn new() -> Self {
        let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
            .map(|x| x.parse::<usize>().expect("integer"))
            .unwrap_or(POOL.current_num_threads().clamp(1, 4));

        if polars_core::config::verbose() {
            eprintln!("async thread count: {n_threads}");
        }

        let rt = Builder::new_multi_thread()
            .worker_threads(n_threads)
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        Self { rt }
    }

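    /// Blocks the current thread until `future` resolves. Uses
    /// `tokio::task::block_in_place` so that, when called from a runtime
    /// worker thread, other tasks are first moved off the worker instead of
    /// being starved.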
    pub fn block_in_place_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        tokio::task::block_in_place(|| self.rt.block_on(future))
    }

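    /// Blocks the current thread until `future` resolves. Must not be called
    /// from within an async context.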
    pub fn block_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        self.rt.block_on(future)
    }

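    /// Spawns `future` onto the async runtime, returning a handle to await it.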
    pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.rt.spawn(future)
    }

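    /// Runs the blocking closure `f` on tokio's dedicated blocking thread pool.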
    pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.rt.spawn_blocking(f)
    }

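    /// Runs `func` on the rayon thread pool and awaits its result. If the
    /// current thread is already a rayon worker the function runs in place
    /// (via `block_in_place`), as spawning onto the pool and blocking from one
    /// of its own threads could deadlock.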
    pub async fn spawn_rayon<F, O>(&self, func: F) -> O
    where
        F: FnOnce() -> O + Send + Sync + 'static,
        O: Send + Sync + 'static,
    {
        if POOL.current_thread_index().is_some() {
            // Already on a rayon thread; run the function in place rather
            // than spawning and blocking on the same pool.
            tokio::task::block_in_place(func)
        } else {
            let (tx, rx) = tokio::sync::oneshot::channel();

            let func = move || {
                let out = func();
                // The awaiting task may have been cancelled (dropping the
                // receiver), so ignore send errors.
                let _ = tx.send(out);
            };

            POOL.spawn(func);

            rx.await.unwrap()
        }
    }
}

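/// The global async runtime, created lazily on first use.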
static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);

pub fn get_runtime() -> &'static RuntimeManager {
    RUNTIME.deref()
}