1use std::error::Error;
2use std::future::Future;
3use std::ops::Deref;
4use std::sync::LazyLock;
5
6use polars_buffer::Buffer;
7use polars_core::POOL;
8use polars_core::config::{self, verbose};
9use polars_utils::relaxed_cell::RelaxedCell;
10use tokio::runtime::{Builder, Runtime};
11use tokio::sync::Semaphore;
12
// Global concurrency budget: the shared semaphore plus the number of permits
// it was initially created with. Lazily initialized in `get_semaphore`.
static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
// Per its name, the maximum budget a single request should consume; also used
// as a floor for the default total budget in `get_semaphore`.
pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;
15
16static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
19 let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
20 .as_deref()
21 .map(|x| x.parse().expect("integer"))
22 .unwrap_or(64 * 1024 * 1024);
23
24 if config::verbose() {
25 eprintln!("async download_chunk_size: {v}")
26 }
27
28 v
29});
30
/// Returns the (lazily computed) download chunk size in bytes.
pub(super) fn get_download_chunk_size() -> usize {
    *DOWNLOAD_CHUNK_SIZE
}
34
/// Number of bytes an object represents; used to feed download statistics
/// into the concurrency tuner (see `tune_with_concurrency_budget`).
pub trait GetSize {
    /// Size of `self` in bytes.
    fn size(&self) -> u64;
}
38
impl GetSize for Buffer<u8> {
    // A byte buffer's size is simply its length.
    fn size(&self) -> u64 {
        self.len() as u64
    }
}
44
45impl<T: GetSize> GetSize for Vec<T> {
46 fn size(&self) -> u64 {
47 self.iter().map(|v| v.size()).sum()
48 }
49}
50
51impl<T: GetSize, E: Error> GetSize for Result<T, E> {
52 fn size(&self) -> u64 {
53 match self {
54 Ok(v) => v.size(),
55 Err(_) => 0,
56 }
57 }
58}
59
/// Wrapper implementing [`GetSize`] for a pre-computed byte count.
#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        self.0
    }
}
#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
75
/// State machine driving the semaphore tuning process (see `SemaphoreTuner::tune`).
enum Optimization {
    /// Add one permit, then move to `Accept` to measure its effect.
    Step,
    /// Keep adding permits while download speed keeps improving.
    Accept,
    /// Tuning converged; no further adjustments are made.
    Finished,
}
81
/// Adaptively grows the concurrency semaphore while download throughput
/// keeps improving, then stops permanently.
struct SemaphoreTuner {
    // Throughput (bytes per millisecond) observed at the previous tuning step.
    previous_download_speed: u64,
    // When the last tuning step ran; `should_tune` throttles steps to >350ms apart.
    last_tune: std::time::Instant,
    // Cumulative bytes downloaded (updated concurrently via `add_stats`).
    downloaded: RelaxedCell<u64>,
    // Cumulative download time in milliseconds (updated via `add_stats`).
    download_time: RelaxedCell<u64>,
    // Current state of the tuning state machine.
    opt_state: Optimization,
    // Number of permits added to the semaphore so far.
    increments: u32,
}
90
91impl SemaphoreTuner {
92 fn new() -> Self {
93 Self {
94 previous_download_speed: 0,
95 last_tune: std::time::Instant::now(),
96 downloaded: RelaxedCell::from(0),
97 download_time: RelaxedCell::from(0),
98 opt_state: Optimization::Step,
99 increments: 0,
100 }
101 }
102 fn should_tune(&self) -> bool {
103 match self.opt_state {
104 Optimization::Finished => false,
105 _ => self.last_tune.elapsed().as_millis() > 350,
106 }
107 }
108
109 fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
110 self.downloaded.fetch_add(downloaded_bytes);
111 self.download_time.fetch_add(download_time);
112 }
113
114 fn increment(&mut self, semaphore: &Semaphore) {
115 semaphore.add_permits(1);
116 self.increments += 1;
117 }
118
119 fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
120 let bytes_downloaded = self.downloaded.load();
121 let time_elapsed = self.download_time.load();
122 let download_speed = bytes_downloaded
123 .checked_div(time_elapsed)
124 .unwrap_or_default();
125
126 let increased = download_speed > self.previous_download_speed;
127 self.previous_download_speed = download_speed;
128 match self.opt_state {
129 Optimization::Step => {
130 self.increment(semaphore);
131 self.opt_state = Optimization::Accept
132 },
133 Optimization::Accept => {
134 if increased {
136 self.increment(semaphore);
138 }
140 else {
142 self.opt_state = Optimization::Finished;
143 FINISHED_TUNING.store(true);
144 if verbose() {
145 eprintln!(
146 "concurrency tuner finished after adding {} steps",
147 self.increments
148 )
149 }
150 return true;
152 }
153 },
154 Optimization::Finished => {},
155 }
156 self.last_tune = std::time::Instant::now();
157 false
159 }
160}
// Counts tuning opportunities; only every 5th one attempts an actual tune.
static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
// Set once tuning converged, or when a fixed budget was supplied via the
// POLARS_CONCURRENCY_BUDGET environment variable.
static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
// Lazily-created tuner behind an async RwLock: readers add stats, a single
// writer performs the tuning step.
static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
    std::sync::OnceLock::new();
165
166fn get_semaphore() -> &'static (Semaphore, u32) {
167 CONCURRENCY_BUDGET.get_or_init(|| {
168 let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
169 .map(|s| {
170 let budget = s.parse::<usize>().expect("integer");
171 FINISHED_TUNING.store(true);
172 budget
173 })
174 .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
175 (Semaphore::new(permits), permits as u32)
176 })
177}
178
179pub(crate) fn get_concurrency_limit() -> u32 {
180 get_semaphore().1
181}
182
/// Run `callable` while holding `requested_budget` permits of the global
/// concurrency semaphore, and opportunistically feed the result's size and
/// duration into the semaphore tuner.
///
/// The output must implement [`GetSize`] so the transferred byte count can be
/// measured; zero-sized results are excluded from tuning.
pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
    Fut::Output: GetSize,
{
    let (semaphore, initial_budget) = get_semaphore();

    // Asking for more permits than the semaphore was created with would
    // never complete, so treat it as a caller bug.
    assert!(requested_budget <= *initial_budget);

    // Hold the permits for the duration of the request.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    let now = std::time::Instant::now();
    let res = callable().await;

    // Skip all tuning bookkeeping once tuning has converged (or was disabled
    // by an explicit POLARS_CONCURRENCY_BUDGET), and for empty results.
    if FINISHED_TUNING.load() || res.size() == 0 {
        return res;
    }

    let duration = now.elapsed().as_millis() as u64;
    let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));

    // Best effort: if the tuner is currently write-locked, skip the stats.
    let Ok(tuner) = permit_store.try_read() else {
        return res;
    };
    tuner.add_stats(res.size(), duration);

    if !tuner.should_tune() {
        return res;
    }
    drop(tuner);

    // Only every 5th task reaching this point attempts an actual tuning step
    // (fetch_add returns the previous counter value).
    if !INCR.fetch_add(1).is_multiple_of(5) {
        return res;
    }
    let Ok(mut tuner) = permit_store.try_write() else {
        return res;
    };
    let finished = tuner.tune(semaphore);
    if finished {
        // `tune` returned true: the last added permit did not improve
        // throughput, so shrink the semaphore by one again. Release our own
        // permits first, then acquire one permit and leak it so it is never
        // returned to the semaphore.
        drop(_permit_acq);
        let undo = semaphore.acquire().await.unwrap();
        std::mem::forget(undo)
    }
    res
}
238
239pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
240where
241 F: FnOnce() -> Fut,
242 Fut: Future,
243{
244 let (semaphore, initial_budget) = get_semaphore();
245
246 assert!(requested_budget <= *initial_budget);
248
249 let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
252
253 callable().await
254}
255
/// Wrapper around a multi-threaded tokio [`Runtime`] configured via
/// environment variables (see `RuntimeManager::new`).
pub struct RuntimeManager {
    rt: Runtime,
}
259
260impl RuntimeManager {
261 fn new() -> Self {
262 let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
263 .map(|x| x.parse::<usize>().expect("integer"))
264 .unwrap_or(usize::min(POOL.current_num_threads(), 32));
265
266 let max_blocking = std::env::var("POLARS_MAX_BLOCKING_THREAD_COUNT")
267 .map(|x| x.parse::<usize>().expect("integer"))
268 .unwrap_or(512);
269
270 if polars_core::config::verbose() {
271 eprintln!("async thread count: {n_threads}");
272 eprintln!("blocking thread count: {max_blocking}");
273 }
274
275 let rt = Builder::new_multi_thread()
276 .worker_threads(n_threads)
277 .max_blocking_threads(max_blocking)
278 .enable_io()
279 .enable_time()
280 .build()
281 .unwrap();
282
283 Self { rt }
284 }
285
286 pub fn block_in_place_on<F>(&self, future: F) -> F::Output
292 where
293 F: Future,
294 {
295 tokio::task::block_in_place(|| self.rt.block_on(future))
296 }
297
298 pub fn block_on<F>(&self, future: F) -> F::Output
301 where
302 F: Future,
303 {
304 self.rt.block_on(future)
305 }
306
307 pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
309 where
310 F: Future + Send + 'static,
311 F::Output: Send + 'static,
312 {
313 self.rt.spawn(future)
314 }
315
316 pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
318 where
319 F: FnOnce() -> R + Send + 'static,
320 R: Send + 'static,
321 {
322 self.rt.spawn_blocking(f)
323 }
324
325 pub async fn spawn_rayon<F, O>(&self, func: F) -> O
329 where
330 F: FnOnce() -> O + Send + Sync + 'static,
331 O: Send + Sync + 'static,
332 {
333 if POOL.current_thread_index().is_some() {
334 tokio::task::block_in_place(func)
338 } else {
339 let (tx, rx) = tokio::sync::oneshot::channel();
340
341 let func = move || {
342 let out = func();
343 let _ = tx.send(out);
345 };
346
347 POOL.spawn(func);
348
349 rx.await.unwrap()
350 }
351 }
352}
353
// Process-global runtime manager, constructed on first use.
static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);
355
356pub fn get_runtime() -> &'static RuntimeManager {
357 RUNTIME.deref()
358}