use std::error::Error;
use std::future::Future;
use std::ops::Deref;
use std::sync::LazyLock;

use polars_core::POOL;
use polars_core::config::{self, verbose};
use polars_utils::relaxed_cell::RelaxedCell;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::Semaphore;

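// Global concurrency budget shared by all cloud requests: a semaphore plus its
// initial permit count, initialized lazily in `get_semaphore`.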
static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;

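// Size (in bytes) of the chunks a download is split into, overridable via
// `POLARS_DOWNLOAD_CHUNK_SIZE`; defaults to 64 MiB.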
static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async download_chunk_size: {v}")
    }

    v
});

pub(super) fn get_download_chunk_size() -> usize {
    *DOWNLOAD_CHUNK_SIZE
}

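/// Byte size of a value, used to report how much data a request transferred so
/// the tuner can compute download throughput.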
pub trait GetSize {
    fn size(&self) -> u64;
}

impl GetSize for bytes::Bytes {
    fn size(&self) -> u64 {
        self.len() as u64
    }
}

impl<T: GetSize> GetSize for Vec<T> {
    fn size(&self) -> u64 {
        self.iter().map(|v| v.size()).sum()
    }
}

impl<T: GetSize, E: Error> GetSize for Result<T, E> {
    fn size(&self) -> u64 {
        match self {
            Ok(v) => v.size(),
            Err(_) => 0,
        }
    }
}

#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        self.0
    }
}

#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Self(value)
    }
}

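// Tuner state machine: `Step` adds a permit to probe, `Accept` checks whether
// throughput improved, `Finished` stops tuning for good.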
enum Optimization {
    Step,
    Accept,
    Finished,
}

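/// Hill-climbing tuner for the concurrency semaphore: it adds permits one at a
/// time and stops as soon as the measured download speed no longer improves.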
struct SemaphoreTuner {
    previous_download_speed: u64,
    last_tune: std::time::Instant,
    downloaded: RelaxedCell<u64>,
    download_time: RelaxedCell<u64>,
    opt_state: Optimization,
    increments: u32,
}

impl SemaphoreTuner {
    fn new() -> Self {
        Self {
            previous_download_speed: 0,
            last_tune: std::time::Instant::now(),
            downloaded: RelaxedCell::from(0),
            download_time: RelaxedCell::from(0),
            opt_state: Optimization::Step,
            increments: 0,
        }
    }

    fn should_tune(&self) -> bool {
        match self.opt_state {
            Optimization::Finished => false,
            _ => self.last_tune.elapsed().as_millis() > 350,
        }
    }

    fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
        self.downloaded.fetch_add(downloaded_bytes);
        self.download_time.fetch_add(download_time);
    }

    fn increment(&mut self, semaphore: &Semaphore) {
        semaphore.add_permits(1);
        self.increments += 1;
    }

    /// Returns `true` once tuning has finished.
    fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
        let bytes_downloaded = self.downloaded.load();
        let time_elapsed = self.download_time.load();
        let download_speed = bytes_downloaded
            .checked_div(time_elapsed)
            .unwrap_or_default();

        let increased = download_speed > self.previous_download_speed;
        self.previous_download_speed = download_speed;
        match self.opt_state {
            // Probe: add a permit and evaluate its effect on the next tune.
            Optimization::Step => {
                self.increment(semaphore);
                self.opt_state = Optimization::Accept
            },
            Optimization::Accept => {
                if increased {
                    // Throughput improved: keep the permit and probe further.
                    self.increment(semaphore);
                } else {
                    // No improvement: stop tuning permanently.
                    self.opt_state = Optimization::Finished;
                    FINISHED_TUNING.store(true);
                    if verbose() {
                        eprintln!(
                            "concurrency tuner finished after adding {} steps",
                            self.increments
                        )
                    }
                    return true;
                }
            },
            Optimization::Finished => {},
        }
        self.last_tune = std::time::Instant::now();
        false
    }
}
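
// `INCR` counts requests so tuning is only attempted every few calls;
// `FINISHED_TUNING` short-circuits all bookkeeping once tuning is done;
// `PERMIT_STORE` holds the tuner behind a lock so only one request tunes at a time.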
static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
    std::sync::OnceLock::new();

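/// Initializes (on first use) and returns the global semaphore together with
/// its initial permit count. Setting `POLARS_CONCURRENCY_BUDGET` pins the
/// budget and disables tuning.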
fn get_semaphore() -> &'static (Semaphore, u32) {
    CONCURRENCY_BUDGET.get_or_init(|| {
        let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
            .map(|s| {
                let budget = s.parse::<usize>().expect("integer");
                FINISHED_TUNING.store(true);
                budget
            })
            .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
        (Semaphore::new(permits), permits as u32)
    })
}

pub(crate) fn get_concurrency_limit() -> u32 {
    get_semaphore().1
}

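/// Runs `callable` under `requested_budget` permits of the global concurrency
/// budget and feeds the size and duration of the result to the tuner, which
/// may grow the budget while throughput keeps improving.
///
/// A minimal usage sketch (`fetch_bytes` is a hypothetical helper returning
/// `bytes::Bytes`, which implements `GetSize` above):
///
/// ```ignore
/// let bytes = tune_with_concurrency_budget(1, || fetch_bytes(url)).await;
/// ```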
pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
    Fut::Output: GetSize,
{
    let (semaphore, initial_budget) = get_semaphore();

    // Requesting more permits than exist would block on `acquire_many` forever.
    assert!(requested_budget <= *initial_budget);

    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    let now = std::time::Instant::now();
    let res = callable().await;

    // Nothing to learn from empty responses, and once tuning has converged we
    // skip all bookkeeping.
    if FINISHED_TUNING.load() || res.size() == 0 {
        return res;
    }

    let duration = now.elapsed().as_millis() as u64;
    let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));

    // Locking is opportunistic: if the tuner is busy, skip rather than wait.
    let Ok(tuner) = permit_store.try_read() else {
        return res;
    };
    tuner.add_stats(res.size(), duration);

    if !tuner.should_tune() {
        return res;
    }
    drop(tuner);

    // Only every fifth request attempts to tune, to limit write contention.
    if !INCR.fetch_add(1).is_multiple_of(5) {
        return res;
    }
    let Ok(mut tuner) = permit_store.try_write() else {
        return res;
    };
    let finished = tuner.tune(semaphore);
    if finished {
        // The last probing permit did not improve throughput: release our
        // permits, then re-acquire one and leak it to shrink the budget by one.
        drop(_permit_acq);
        let undo = semaphore.acquire().await.unwrap();
        std::mem::forget(undo)
    }
    res
}

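/// Runs `callable` under `requested_budget` permits of the global concurrency
/// budget, without recording statistics for the tuner.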
pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
{
    let (semaphore, initial_budget) = get_semaphore();

    assert!(requested_budget <= *initial_budget);

    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    callable().await
}

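/// Owns the multi-threaded tokio runtime used for async I/O, kept separate
/// from the rayon compute pool (`POOL`).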
pub struct RuntimeManager {
    rt: Runtime,
}

impl RuntimeManager {
    fn new() -> Self {
        let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
            .map(|x| x.parse::<usize>().expect("integer"))
            .unwrap_or(POOL.current_num_threads().clamp(1, 4));

        if polars_core::config::verbose() {
            eprintln!("async thread count: {n_threads}");
        }

        let rt = Builder::new_multi_thread()
            .worker_threads(n_threads)
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        Self { rt }
    }

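    /// Blocks the current thread until `future` resolves. `block_in_place`
    /// lets this be called from a multi-threaded runtime worker by moving that
    /// worker's other tasks elsewhere first.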
    pub fn block_in_place_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        tokio::task::block_in_place(|| self.rt.block_on(future))
    }

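    /// Blocks on `future` directly. Panics if called from within an async
    /// context; use `block_in_place_on` there instead.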
    pub fn block_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        self.rt.block_on(future)
    }

    pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.rt.spawn(future)
    }

    pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.rt.spawn_blocking(f)
    }

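    /// Runs `func` on the rayon thread pool and awaits its result without
    /// blocking a tokio worker.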
    pub async fn spawn_rayon<F, O>(&self, func: F) -> O
    where
        F: FnOnce() -> O + Send + Sync + 'static,
        O: Send + Sync + 'static,
    {
        if POOL.current_thread_index().is_some() {
            // Already on a rayon thread: spawning onto the pool and waiting on
            // the result could deadlock, so run the function in place instead.
            tokio::task::block_in_place(func)
        } else {
            let (tx, rx) = tokio::sync::oneshot::channel();

            let func = move || {
                let out = func();
                // Ignore send errors: the receiver only disappears if this
                // future was dropped.
                let _ = tx.send(out);
            };

            POOL.spawn(func);

            rx.await.unwrap()
        }
    }
}

static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);

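/// Returns the process-global runtime manager.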
pub fn get_runtime() -> &'static RuntimeManager {
    RUNTIME.deref()
}