1use std::error::Error;
2use std::future::Future;
3use std::sync::LazyLock;
4
5use polars_buffer::Buffer;
6use polars_core::config::{self, verbose};
7use polars_core::runtime::RAYON;
8use polars_utils::relaxed_cell::RelaxedCell;
9use tokio::sync::Semaphore;
10
11static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
12pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;
13
14static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
17 let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
18 .as_deref()
19 .map(|x| x.parse().expect("integer"))
20 .unwrap_or(64 * 1024 * 1024);
21
22 if config::verbose() {
23 eprintln!("async download_chunk_size: {v}")
24 }
25
26 v
27});
28
29pub(super) fn get_download_chunk_size() -> usize {
30 *DOWNLOAD_CHUNK_SIZE
31}
32
/// Reports a size for a value.
///
/// `tune_with_concurrency_budget` uses this on the output of a request to
/// account for the amount of data that was downloaded.
pub trait GetSize {
    /// Size of `self` as a `u64`.
    fn size(&self) -> u64;
}
36
37impl GetSize for Buffer<u8> {
38 fn size(&self) -> u64 {
39 self.len() as u64
40 }
41}
42
43impl<T: GetSize> GetSize for Vec<T> {
44 fn size(&self) -> u64 {
45 self.iter().map(|v| v.size()).sum()
46 }
47}
48
49impl<T: GetSize, E: Error> GetSize for Result<T, E> {
50 fn size(&self) -> u64 {
51 match self {
52 Ok(v) => v.size(),
53 Err(_) => 0,
54 }
55 }
56}
57
/// Wrapper around a plain `u64` so that a bare size value can be used where a
/// [`GetSize`] implementor is required. Only compiled with the `cloud` feature.
#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

/// The wrapped value is reported unchanged.
#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        let Self(inner) = self;
        *inner
    }
}

/// Wraps a raw `u64` into a [`Size`].
#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Size(value)
    }
}
73
/// State of the concurrency tuner (see `SemaphoreTuner::tune`).
enum Optimization {
    /// Initial state: the next tune adds one permit unconditionally and moves
    /// to `Accept`.
    Step,
    /// A permit was added; the next tune adds another one if the measured
    /// download speed increased, otherwise it moves to `Finished`.
    Accept,
    /// Tuning is over; no further permits are added.
    Finished,
}
79
80struct SemaphoreTuner {
81 previous_download_speed: u64,
82 last_tune: std::time::Instant,
83 downloaded: RelaxedCell<u64>,
84 download_time: RelaxedCell<u64>,
85 opt_state: Optimization,
86 increments: u32,
87}
88
89impl SemaphoreTuner {
90 fn new() -> Self {
91 Self {
92 previous_download_speed: 0,
93 last_tune: std::time::Instant::now(),
94 downloaded: RelaxedCell::from(0),
95 download_time: RelaxedCell::from(0),
96 opt_state: Optimization::Step,
97 increments: 0,
98 }
99 }
100 fn should_tune(&self) -> bool {
101 match self.opt_state {
102 Optimization::Finished => false,
103 _ => self.last_tune.elapsed().as_millis() > 350,
104 }
105 }
106
107 fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
108 self.downloaded.fetch_add(downloaded_bytes);
109 self.download_time.fetch_add(download_time);
110 }
111
112 fn increment(&mut self, semaphore: &Semaphore) {
113 semaphore.add_permits(1);
114 self.increments += 1;
115 }
116
117 fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
118 let bytes_downloaded = self.downloaded.load();
119 let time_elapsed = self.download_time.load();
120 let download_speed = bytes_downloaded
121 .checked_div(time_elapsed)
122 .unwrap_or_default();
123
124 let increased = download_speed > self.previous_download_speed;
125 self.previous_download_speed = download_speed;
126 match self.opt_state {
127 Optimization::Step => {
128 self.increment(semaphore);
129 self.opt_state = Optimization::Accept
130 },
131 Optimization::Accept => {
132 if increased {
134 self.increment(semaphore);
136 }
138 else {
140 self.opt_state = Optimization::Finished;
141 FINISHED_TUNING.store(true);
142 if verbose() {
143 eprintln!(
144 "concurrency tuner finished after adding {} steps",
145 self.increments
146 )
147 }
148 return true;
150 }
151 },
152 Optimization::Finished => {},
153 }
154 self.last_tune = std::time::Instant::now();
155 false
157 }
158}
159static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
160static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
161static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
162 std::sync::OnceLock::new();
163
164fn get_semaphore() -> &'static (Semaphore, u32) {
165 CONCURRENCY_BUDGET.get_or_init(|| {
166 let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
167 .map(|s| {
168 let budget = s.parse::<usize>().expect("integer");
169 FINISHED_TUNING.store(true);
170 budget
171 })
172 .unwrap_or_else(|_| std::cmp::max(RAYON.current_num_threads(), MAX_BUDGET_PER_REQUEST));
173 (Semaphore::new(permits), permits as u32)
174 })
175}
176
177pub(crate) fn get_concurrency_limit() -> u32 {
178 get_semaphore().1
179}
180
181pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
182where
183 F: FnOnce() -> Fut,
184 Fut: Future,
185 Fut::Output: GetSize,
186{
187 let (semaphore, initial_budget) = get_semaphore();
188
189 assert!(requested_budget <= *initial_budget);
191
192 let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
195
196 let now = std::time::Instant::now();
197 let res = callable().await;
198
199 if FINISHED_TUNING.load() || res.size() == 0 {
200 return res;
201 }
202
203 let duration = now.elapsed().as_millis() as u64;
204 let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));
205
206 let Ok(tuner) = permit_store.try_read() else {
207 return res;
208 };
209 tuner.add_stats(res.size(), duration);
211
212 if !tuner.should_tune() {
214 return res;
215 }
216 drop(tuner);
218
219 if !INCR.fetch_add(1).is_multiple_of(5) {
221 return res;
222 }
223 let Ok(mut tuner) = permit_store.try_write() else {
225 return res;
226 };
227 let finished = tuner.tune(semaphore);
228 if finished {
229 drop(_permit_acq);
230 let undo = semaphore.acquire().await.unwrap();
232 std::mem::forget(undo)
233 }
234 res
235}
236
237pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
238where
239 F: FnOnce() -> Fut,
240 Fut: Future,
241{
242 let (semaphore, initial_budget) = get_semaphore();
243
244 assert!(requested_budget <= *initial_budget);
246
247 let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
250
251 callable().await
252}