1#[cfg(feature = "aws")]
2use std::io::Read;
3#[cfg(feature = "aws")]
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::LazyLock;
7
8#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
9use object_store::ClientOptions;
10#[cfg(feature = "aws")]
11use object_store::aws::AmazonS3Builder;
12#[cfg(feature = "aws")]
13pub use object_store::aws::AmazonS3ConfigKey;
14#[cfg(feature = "azure")]
15pub use object_store::azure::AzureConfigKey;
16#[cfg(feature = "azure")]
17use object_store::azure::MicrosoftAzureBuilder;
18#[cfg(feature = "gcp")]
19use object_store::gcp::GoogleCloudStorageBuilder;
20#[cfg(feature = "gcp")]
21pub use object_store::gcp::GoogleConfigKey;
22#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
23use object_store::{BackoffConfig, RetryConfig};
24use polars_error::*;
25#[cfg(feature = "aws")]
26use polars_utils::cache::LruCache;
27use polars_utils::plpath::{CloudScheme, PlPathRef};
28#[cfg(feature = "http")]
29use reqwest::header::HeaderMap;
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33#[cfg(feature = "cloud")]
34use super::credential_provider::PlCredentialProvider;
35#[cfg(feature = "file_cache")]
36use crate::file_cache::get_env_file_cache_ttl;
37#[cfg(feature = "aws")]
38use crate::pl_async::with_concurrency_budget;
39
/// Process-wide cache that maps a bucket name to the region resolved for it.
///
/// The cache is created lazily on first use, is guarded by a mutex and is an LRU cache
/// constructed with a capacity of 32.
#[cfg(feature = "aws")]
static BUCKET_REGION: LazyLock<
    std::sync::Mutex<LruCache<polars_utils::pl_str::PlSmallStr, polars_utils::pl_str::PlSmallStr>>,
> = LazyLock::new(|| std::sync::Mutex::new(LruCache::with_capacity(32)));
44
/// A list of `(typed configuration key, value)` pairs for a single cloud provider.
// `dead_code` is allowed because the alias is only used when a provider feature is enabled.
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;
53
/// Provider-specific configuration stored inside [`CloudOptions`].
///
/// Every variant is only compiled in when the matching cargo feature is enabled.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub(crate) enum CloudConfig {
    /// Typed key/value settings applied to the Amazon S3 builder.
    #[cfg(feature = "aws")]
    Aws(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AmazonS3ConfigKey>,
    ),
    /// Typed key/value settings applied to the Microsoft Azure builder.
    #[cfg(feature = "azure")]
    Azure(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AzureConfigKey>,
    ),
    /// Typed key/value settings applied to the Google Cloud Storage builder.
    #[cfg(feature = "gcp")]
    Gcp(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<GoogleConfigKey>,
    ),
    /// `(name, value)` header pairs installed as default headers on the HTTP client.
    #[cfg(feature = "http")]
    Http { headers: Vec<(String, String)> },
}
76
/// Options used when building an object store for a cloud location.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct CloudOptions {
    /// Maximum number of retries; forwarded to the object store retry configuration.
    pub max_retries: usize,
    /// Time-to-live of file cache entries; the default comes from `get_env_file_cache_ttl()`.
    // NOTE(review): the unit is presumably seconds — confirm against `file_cache`.
    #[cfg(feature = "file_cache")]
    pub file_cache_ttl: u64,
    /// Provider-specific configuration, if any was supplied.
    pub(crate) config: Option<CloudConfig>,
    /// Optional credential provider that is handed to the object store builders.
    #[cfg(feature = "cloud")]
    pub(crate) credential_provider: Option<PlCredentialProvider>,
}
91
92impl Default for CloudOptions {
93 fn default() -> Self {
94 Self::default_static_ref().clone()
95 }
96}
97
98impl CloudOptions {
99 pub fn default_static_ref() -> &'static Self {
100 static DEFAULT: LazyLock<CloudOptions> = LazyLock::new(|| CloudOptions {
101 max_retries: 2,
102 #[cfg(feature = "file_cache")]
103 file_cache_ttl: get_env_file_cache_ttl(),
104 config: None,
105 #[cfg(feature = "cloud")]
106 credential_provider: None,
107 });
108
109 &DEFAULT
110 }
111}
112
/// Converts `(name, value)` string pairs into a [`HeaderMap`].
///
/// Pairs are inserted in order, so a later pair with the same name replaces an earlier one.
///
/// # Errors
///
/// Returns a compute error if a name or a value is not a valid HTTP header name / value.
#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
    headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
    use reqwest::header::{HeaderName, HeaderValue};

    headers.iter().try_fold(
        HeaderMap::with_capacity(headers.len()),
        |mut acc, (name, value)| -> PolarsResult<HeaderMap> {
            let name = HeaderName::from_str(name.as_ref()).map_err(to_compute_err)?;
            let value = HeaderValue::from_str(value.as_ref()).map_err(to_compute_err)?;
            acc.insert(name, value);
            Ok(acc)
        },
    )
}
130
131#[allow(dead_code)]
132fn parse_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
134 config: I,
135) -> PolarsResult<Configs<T>>
136where
137 T: FromStr + Eq + std::hash::Hash,
138{
139 Ok(config
140 .into_iter()
141 .filter_map(|(key, val)| {
143 T::from_str(key.as_ref().to_ascii_lowercase().as_str())
144 .ok()
145 .map(|typed_key| (typed_key, val.into()))
146 })
147 .collect::<Configs<T>>())
148}
149
/// The storage backend family that a location belongs to.
///
/// See [`CloudType::from_cloud_scheme`] for the mapping from URL schemes.
#[derive(Debug, Clone, PartialEq)]
pub enum CloudType {
    /// Selected for the `S3` and `S3a` schemes.
    Aws,
    /// Selected for the `Abfs`, `Abfss`, `Adl`, `Az` and `Azure` schemes.
    Azure,
    /// Selected for the `File` and `FileNoHostname` schemes, and when no scheme is given to
    /// `CloudOptions::from_untyped_config`.
    File,
    /// Selected for the `Gcs` and `Gs` schemes.
    Gcp,
    /// Selected for the `Http` and `Https` schemes.
    Http,
    /// Selected for the `Hf` scheme.
    Hf,
}
161
162impl CloudType {
163 pub fn from_cloud_scheme(scheme: &CloudScheme) -> Self {
164 match scheme {
165 CloudScheme::Abfs
166 | CloudScheme::Abfss
167 | CloudScheme::Adl
168 | CloudScheme::Az
169 | CloudScheme::Azure => Self::Azure,
170
171 CloudScheme::File | CloudScheme::FileNoHostname => Self::File,
172
173 CloudScheme::Gcs | CloudScheme::Gs => Self::Gcp,
174
175 CloudScheme::Hf => Self::Hf,
176
177 CloudScheme::Http | CloudScheme::Https => Self::Http,
178
179 CloudScheme::S3 | CloudScheme::S3a => Self::Aws,
180 }
181 }
182}
183
/// Builds the retry configuration shared by the cloud builders: default backoff,
/// `max_retries` attempts and a 10 second retry timeout.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
fn get_retry_config(max_retries: usize) -> RetryConfig {
    let retry_timeout = std::time::Duration::from_secs(10);
    let backoff = BackoffConfig::default();

    RetryConfig {
        backoff,
        max_retries,
        retry_timeout,
    }
}
192
/// Returns the client options shared by all object store builders: the request timeout and
/// the connect timeout are both disabled, and plain HTTP is allowed.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
    let options = ClientOptions::new();
    let options = options.with_timeout_disabled();
    let options = options.with_connect_timeout_disabled();
    options.with_allow_http(true)
}
204
/// Fills unset builder keys from local configuration files.
///
/// `items` pairs a file path (a leading `~` is resolved via `resolve_homedir`) with a list of
/// `(regex, key)` entries. For every key that the builder does not have a value for yet, the
/// first capture group of the regex is looked up in the file and stored under that key.
/// Files whose keys are all already set are not opened.
///
/// Returns `None` as soon as a file cannot be opened or read as UTF-8, or a pattern has no
/// match; entries after that point are not processed. Returns `Some(())` otherwise.
#[cfg(feature = "aws")]
fn read_config(
    builder: &mut AmazonS3Builder,
    items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
    use crate::path_utils::resolve_homedir;

    for (path, keys) in items {
        let has_unset_key = keys
            .iter()
            .any(|(_, key)| builder.get_config_value(key).is_none());
        if !has_unset_key {
            continue;
        }

        let mut file = std::fs::File::open(resolve_homedir(path)).ok()?;
        let mut content = String::new();
        file.read_to_string(&mut content).ok()?;

        for (pattern, key) in keys.iter() {
            // Values that are already configured take precedence over the file.
            if builder.get_config_value(key).is_some() {
                continue;
            }

            let regex = polars_utils::regex_cache::compile_regex(pattern).unwrap();
            let captures = regex.captures(content.as_str())?;
            let value = captures.get(1)?.as_str();
            // `with_config` consumes the builder, so it is moved out and written back.
            *builder = std::mem::take(builder).with_config(*key, value);
        }
    }

    Some(())
}
237
238impl CloudOptions {
239 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
241 self.max_retries = max_retries;
242 self
243 }
244
/// Returns these options with the credential provider replaced by `credential_provider`
/// (`None` removes any previously set provider).
#[cfg(feature = "cloud")]
pub fn with_credential_provider(
    self,
    credential_provider: Option<PlCredentialProvider>,
) -> Self {
    Self {
        credential_provider,
        ..self
    }
}
253
/// Returns these options with the provider configuration replaced by the given S3 settings.
#[cfg(feature = "aws")]
pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
    self,
    configs: I,
) -> Self {
    let typed_pairs: Vec<_> = configs
        .into_iter()
        .map(|(key, value)| (key, value.into()))
        .collect();

    Self {
        config: Some(CloudConfig::Aws(typed_pairs)),
        ..self
    }
}
265
/// Builds an S3 object store for `url`.
///
/// The builder is configured in the following order:
///
/// 1. Environment variables (`AmazonS3Builder::from_env`), the shared client options and `url`.
/// 2. Storage options reported by the credential provider, if one is set.
/// 3. `~/.aws/config` (region) and `~/.aws/credentials` (access key id, secret access key,
///    session token); these only fill keys that are still unset.
/// 4. The explicit AWS settings stored in these options, which override earlier values.
/// 5. If neither a default region nor a region is set: the cached region for the bucket, else
///    `us-east-1` when an endpoint is configured, else the `x-amz-bucket-region` header of a
///    `HEAD` request to `https://{bucket}.s3.amazonaws.com` (the result is cached).
/// 6. The retry configuration and, where applicable, the credential provider.
///
/// # Errors
///
/// Returns an error if the credential provider cannot be initialized or queried, the bucket
/// cannot be parsed from `url`, the region lookup request fails, or the store cannot be built.
///
/// # Panics
///
/// Panics if the stored provider configuration is not the AWS variant, or if the bucket
/// region cache mutex is poisoned.
#[cfg(feature = "aws")]
pub async fn build_aws(
    &self,
    url: &str,
    clear_cached_credentials: bool,
) -> PolarsResult<impl object_store::ObjectStore> {
    use super::credential_provider::IntoCredentialProvider;

    let opt_credential_provider =
        self.initialized_credential_provider(clear_cached_credentials)?;

    let mut builder = AmazonS3Builder::from_env()
        .with_client_options(get_client_options())
        .with_url(url);

    if let Some(credential_provider) = &opt_credential_provider {
        let storage_update_options = parse_untyped_config::<AmazonS3ConfigKey, _>(
            credential_provider
                .storage_update_options()?
                .into_iter()
                .map(|(k, v)| (k, v.to_string())),
        )?;

        for (key, value) in storage_update_options {
            builder = builder.with_config(key, value);
        }
    }

    // Best effort: the result is ignored because a missing or partial file is not an error.
    read_config(
        &mut builder,
        &[(
            Path::new("~/.aws/config"),
            &[("region\\s*=\\s*([^\r\n]*)", AmazonS3ConfigKey::Region)],
        )],
    );

    read_config(
        &mut builder,
        &[(
            Path::new("~/.aws/credentials"),
            &[
                (
                    "aws_access_key_id\\s*=\\s*([^\\r\\n]*)",
                    AmazonS3ConfigKey::AccessKeyId,
                ),
                (
                    "aws_secret_access_key\\s*=\\s*([^\\r\\n]*)",
                    AmazonS3ConfigKey::SecretAccessKey,
                ),
                (
                    "aws_session_token\\s*=\\s*([^\\r\\n]*)",
                    AmazonS3ConfigKey::Token,
                ),
            ],
        )],
    );

    // Explicit settings are applied last so that they win over the environment and files.
    if let Some(options) = &self.config {
        let CloudConfig::Aws(options) = options else {
            panic!("impl error: cloud type mismatch")
        };
        for (key, value) in options {
            builder = builder.with_config(*key, value);
        }
    }

    if builder
        .get_config_value(&AmazonS3ConfigKey::DefaultRegion)
        .is_none()
        && builder
            .get_config_value(&AmazonS3ConfigKey::Region)
            .is_none()
    {
        let bucket = crate::cloud::CloudLocation::new(PlPathRef::new(url), false)?.bucket;
        let region = {
            // The guard is dropped at the end of this block, before any `.await`.
            let mut bucket_region = BUCKET_REGION.lock().unwrap();
            bucket_region.get(bucket.as_str()).cloned()
        };

        match region {
            Some(region) => {
                builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str())
            },
            None => {
                if builder
                    .get_config_value(&AmazonS3ConfigKey::Endpoint)
                    .is_some()
                {
                    builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1");
                } else {
                    polars_warn!(
                        "'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."
                    );
                    let result = with_concurrency_budget(1, || async {
                        // A client that fails to initialize is reported as an error
                        // instead of panicking.
                        reqwest::Client::builder()
                            .build()
                            .map_err(to_compute_err)?
                            .head(format!("https://{bucket}.s3.amazonaws.com"))
                            .send()
                            .await
                            .map_err(to_compute_err)
                    })
                    .await?;
                    if let Some(region) = result.headers().get("x-amz-bucket-region") {
                        let region =
                            std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?;
                        let mut bucket_region = BUCKET_REGION.lock().unwrap();
                        bucket_region.insert(bucket, region.into());
                        builder = builder.with_config(AmazonS3ConfigKey::Region, region)
                    }
                }
            },
        };
    };

    let builder = builder.with_retry(get_retry_config(self.max_retries));

    let opt_credential_provider = match opt_credential_provider {
        #[cfg(feature = "python")]
        Some(PlCredentialProvider::Python(object)) => {
            // A Python provider may opt out through `_can_use_as_provider`; providers
            // without that attribute are always used.
            if pyo3::Python::with_gil(|py| {
                let Ok(func_object) = object
                    .unwrap_as_provider_ref()
                    .getattr(py, "_can_use_as_provider")
                else {
                    return PolarsResult::Ok(true);
                };

                // A non-bool return value is surfaced as an error rather than a panic.
                Ok(func_object.call0(py)?.extract::<bool>(py)?)
            })? {
                Some(PlCredentialProvider::Python(object))
            } else {
                None
            }
        },

        v => v,
    };

    let builder = if let Some(credential_provider) = opt_credential_provider {
        builder.with_credentials(credential_provider.into_aws_provider())
    } else {
        builder
    };

    let out = builder.build()?;

    Ok(out)
}
418
/// Returns these options with the provider configuration replaced by the given Azure settings.
#[cfg(feature = "azure")]
pub fn with_azure<I: IntoIterator<Item = (AzureConfigKey, impl Into<String>)>>(
    self,
    configs: I,
) -> Self {
    let typed_pairs: Vec<_> = configs
        .into_iter()
        .map(|(key, value)| (key, value.into()))
        .collect();

    Self {
        config: Some(CloudConfig::Azure(typed_pairs)),
        ..self
    }
}
430
/// Builds an Azure object store for `url`.
///
/// The builder starts from the environment and the shared client options, then applies the
/// explicit Azure settings, `url`, the retry configuration and finally the credential
/// provider (if one is set). In verbose mode the chosen provider is logged to stderr.
///
/// # Errors
///
/// Returns an error if the credential provider cannot be initialized or the store cannot be
/// built.
///
/// # Panics
///
/// Panics if the stored provider configuration is not the Azure variant.
#[cfg(feature = "azure")]
pub fn build_azure(
    &self,
    url: &str,
    clear_cached_credentials: bool,
) -> PolarsResult<impl object_store::ObjectStore> {
    use super::credential_provider::IntoCredentialProvider;

    let verbose = polars_core::config::verbose();

    let mut builder =
        MicrosoftAzureBuilder::from_env().with_client_options(get_client_options());

    if let Some(config) = self.config.as_ref() {
        let CloudConfig::Azure(entries) = config else {
            panic!("impl error: cloud type mismatch")
        };
        builder = entries
            .iter()
            .fold(builder, |acc, (key, value)| acc.with_config(*key, value));
    }

    let mut builder = builder
        .with_url(url)
        .with_retry(get_retry_config(self.max_retries));

    if let Some(provider) = self.initialized_credential_provider(clear_cached_credentials)? {
        if verbose {
            eprintln!(
                "[CloudOptions::build_azure]: Using credential provider {:?}",
                &provider
            );
        }
        builder = builder.with_credentials(provider.into_azure_provider());
    }

    Ok(builder.build()?)
}
477
/// Returns these options with the provider configuration replaced by the given GCP settings.
#[cfg(feature = "gcp")]
pub fn with_gcp<I: IntoIterator<Item = (GoogleConfigKey, impl Into<String>)>>(
    self,
    configs: I,
) -> Self {
    let typed_pairs: Vec<_> = configs
        .into_iter()
        .map(|(key, value)| (key, value.into()))
        .collect();

    Self {
        config: Some(CloudConfig::Gcp(typed_pairs)),
        ..self
    }
}
489
/// Builds a Google Cloud Storage object store for `url`.
///
/// The environment is only consulted (`GoogleCloudStorageBuilder::from_env`) when no
/// credential provider is set; otherwise an empty builder is used. The shared client options,
/// the explicit GCP settings, `url`, the retry configuration and the credential provider are
/// then applied in that order.
///
/// # Errors
///
/// Returns an error if the credential provider cannot be initialized or the store cannot be
/// built.
///
/// # Panics
///
/// Panics if the stored provider configuration is not the GCP variant.
#[cfg(feature = "gcp")]
pub fn build_gcp(
    &self,
    url: &str,
    clear_cached_credentials: bool,
) -> PolarsResult<impl object_store::ObjectStore> {
    use super::credential_provider::IntoCredentialProvider;

    let credential_provider = self.initialized_credential_provider(clear_cached_credentials)?;

    let base = match &credential_provider {
        None => GoogleCloudStorageBuilder::from_env(),
        Some(_) => GoogleCloudStorageBuilder::new(),
    };
    let mut builder = base.with_client_options(get_client_options());

    if let Some(config) = self.config.as_ref() {
        let CloudConfig::Gcp(entries) = config else {
            panic!("impl error: cloud type mismatch")
        };
        builder = entries
            .iter()
            .fold(builder, |acc, (key, value)| acc.with_config(*key, value));
    }

    let mut builder = builder
        .with_url(url)
        .with_retry(get_retry_config(self.max_retries));

    if let Some(provider) = credential_provider {
        builder = builder.with_credentials(provider.into_gcp_provider());
    }

    Ok(builder.build()?)
}
532
/// Builds an HTTP object store for `url`.
///
/// When the stored configuration is the HTTP variant, its headers are installed as default
/// headers on the client.
///
/// # Errors
///
/// Returns an error if a configured header is invalid or the store cannot be built.
#[cfg(feature = "http")]
pub fn build_http(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
    let mut client_options = super::get_client_options();

    if let Some(CloudConfig::Http { headers }) = &self.config {
        let header_map = try_build_http_header_map_from_items_slice(headers.as_slice())?;
        client_options = client_options.with_default_headers(header_map);
    }

    let store = object_store::http::HttpBuilder::new()
        .with_url(url)
        .with_client_options(client_options)
        .build()?;

    Ok(store)
}
550
551 #[allow(unused_variables)]
553 pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
554 scheme: Option<&CloudScheme>,
555 config: I,
556 ) -> PolarsResult<Self> {
557 match scheme.map_or(CloudType::File, CloudType::from_cloud_scheme) {
558 CloudType::Aws => {
559 #[cfg(feature = "aws")]
560 {
561 parse_untyped_config::<AmazonS3ConfigKey, _>(config)
562 .map(|aws| Self::default().with_aws(aws))
563 }
564 #[cfg(not(feature = "aws"))]
565 {
566 polars_bail!(ComputeError: "'aws' feature is not enabled");
567 }
568 },
569 CloudType::Azure => {
570 #[cfg(feature = "azure")]
571 {
572 parse_untyped_config::<AzureConfigKey, _>(config)
573 .map(|azure| Self::default().with_azure(azure))
574 }
575 #[cfg(not(feature = "azure"))]
576 {
577 polars_bail!(ComputeError: "'azure' feature is not enabled");
578 }
579 },
580 CloudType::File => Ok(Self::default()),
581 CloudType::Http => Ok(Self::default()),
582 CloudType::Gcp => {
583 #[cfg(feature = "gcp")]
584 {
585 parse_untyped_config::<GoogleConfigKey, _>(config)
586 .map(|gcp| Self::default().with_gcp(gcp))
587 }
588 #[cfg(not(feature = "gcp"))]
589 {
590 polars_bail!(ComputeError: "'gcp' feature is not enabled");
591 }
592 },
593 CloudType::Hf => {
594 #[cfg(feature = "http")]
595 {
596 use polars_core::config;
597
598 use crate::path_utils::resolve_homedir;
599
600 let mut this = Self::default();
601 let mut token = None;
602 let verbose = config::verbose();
603
604 for (i, (k, v)) in config.into_iter().enumerate() {
605 let (k, v) = (k.as_ref(), v.into());
606
607 if i == 0 && k == "token" {
608 if verbose {
609 eprintln!("HF token sourced from storage_options");
610 }
611 token = Some(v);
612 } else {
613 polars_bail!(ComputeError: "unknown configuration key for HF: {}", k)
614 }
615 }
616
617 token = token
618 .or_else(|| {
619 let v = std::env::var("HF_TOKEN").ok();
620 if v.is_some() && verbose {
621 eprintln!("HF token sourced from HF_TOKEN env var");
622 }
623 v
624 })
625 .or_else(|| {
626 let hf_home = std::env::var("HF_HOME");
627 let hf_home = hf_home.as_deref();
628 let hf_home = hf_home.unwrap_or("~/.cache/huggingface");
629 let hf_home = resolve_homedir(&hf_home);
630 let cached_token_path = hf_home.join("token");
631
632 let v = std::string::String::from_utf8(
633 std::fs::read(&cached_token_path).ok()?,
634 )
635 .ok()
636 .filter(|x| !x.is_empty());
637
638 if v.is_some() && verbose {
639 eprintln!(
640 "HF token sourced from {}",
641 cached_token_path.to_str().unwrap()
642 );
643 }
644
645 v
646 });
647
648 if let Some(v) = token {
649 this.config = Some(CloudConfig::Http {
650 headers: vec![("Authorization".into(), format!("Bearer {v}"))],
651 })
652 }
653
654 Ok(this)
655 }
656 #[cfg(not(feature = "http"))]
657 {
658 polars_bail!(ComputeError: "'http' feature is not enabled");
659 }
660 },
661 }
662 }
663
/// Returns an initialized copy of the configured credential provider, or `Ok(None)` when no
/// provider is set. `clear_cached_credentials` is forwarded to the provider's initialization.
#[cfg(feature = "cloud")]
fn initialized_credential_provider(
    &self,
    clear_cached_credentials: bool,
) -> PolarsResult<Option<PlCredentialProvider>> {
    match &self.credential_provider {
        Some(provider) => provider
            .clone()
            .try_into_initialized(clear_cached_credentials),
        None => Ok(None),
    }
}
677}
678
#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
    use hashbrown::HashMap;

    use super::parse_untyped_config;

    /// Known keys are parsed regardless of ASCII case, and unknown keys are dropped.
    #[cfg(feature = "aws")]
    #[test]
    fn test_parse_untyped_config() {
        use object_store::aws::AmazonS3ConfigKey;

        // The same expectations hold for the lower-case and the upper-case spelling.
        for secret_key_name in ["aws_secret_access_key", "AWS_SECRET_ACCESS_KEY"] {
            let aws_config: HashMap<_, _> = [
                (secret_key_name, "a_key"),
                ("aws_s3_allow_unsafe_rename", "true"),
            ]
            .into_iter()
            .collect();

            let aws_keys = parse_untyped_config::<AmazonS3ConfigKey, _>(aws_config)
                .expect("Parsing keys shouldn't have thrown an error");

            assert_eq!(
                aws_keys.first().unwrap().0,
                AmazonS3ConfigKey::SecretAccessKey
            );
            assert_eq!(aws_keys.len(), 1);
        }
    }
}