1#[cfg(feature = "aws")]
2use std::io::Read;
3#[cfg(feature = "aws")]
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::LazyLock;
7
8#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
9use object_store::ClientOptions;
10#[cfg(feature = "aws")]
11use object_store::aws::AmazonS3Builder;
12#[cfg(feature = "aws")]
13pub use object_store::aws::AmazonS3ConfigKey;
14#[cfg(feature = "azure")]
15pub use object_store::azure::AzureConfigKey;
16#[cfg(feature = "azure")]
17use object_store::azure::MicrosoftAzureBuilder;
18#[cfg(feature = "gcp")]
19use object_store::gcp::GoogleCloudStorageBuilder;
20#[cfg(feature = "gcp")]
21pub use object_store::gcp::GoogleConfigKey;
22#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
23use object_store::{BackoffConfig, RetryConfig};
24use polars_error::*;
25#[cfg(feature = "aws")]
26use polars_utils::cache::LruCache;
27use polars_utils::plpath::{CloudScheme, PlPathRef};
28#[cfg(feature = "http")]
29use reqwest::header::HeaderMap;
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33#[cfg(feature = "cloud")]
34use super::credential_provider::PlCredentialProvider;
35#[cfg(feature = "file_cache")]
36use crate::file_cache::get_env_file_cache_ttl;
37#[cfg(feature = "aws")]
38use crate::pl_async::with_concurrency_budget;
39
/// Process-wide cache mapping S3 bucket names to their region.
///
/// `CloudOptions::build_aws` consults it before falling back to an HTTP lookup and stores any
/// region it discovers, so the lookup is not repeated for recently used buckets (LRU, 32 entries).
#[cfg(feature = "aws")]
static BUCKET_REGION: LazyLock<
    std::sync::Mutex<LruCache<polars_utils::pl_str::PlSmallStr, polars_utils::pl_str::PlSmallStr>>,
> = LazyLock::new(|| {
    let cache = LruCache::with_capacity(32);
    std::sync::Mutex::new(cache)
});
44
/// Typed configuration entries: `(key, value)` pairs, kept in the order they were supplied.
// Only used by the cloud backend features; silence dead-code warnings when none is enabled.
#[allow(dead_code)]
type Configs<T> = std::vec::Vec<(T, std::string::String)>;
53
/// Backend-specific configuration stored inside [`CloudOptions`].
///
/// Each variant only exists when its corresponding crate feature is enabled.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub(crate) enum CloudConfig {
    /// Typed Amazon S3 options, applied by `CloudOptions::build_aws`.
    #[cfg(feature = "aws")]
    Aws(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AmazonS3ConfigKey>,
    ),
    /// Typed Azure options, applied by `CloudOptions::build_azure`.
    #[cfg(feature = "azure")]
    Azure(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AzureConfigKey>,
    ),
    /// Typed Google Cloud Storage options, applied by `CloudOptions::build_gcp`.
    #[cfg(feature = "gcp")]
    Gcp(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<GoogleConfigKey>,
    ),
    /// `(name, value)` header pairs that `CloudOptions::build_http` installs as the client's
    /// default headers.
    #[cfg(feature = "http")]
    Http { headers: Vec<(String, String)> },
}
76
77#[derive(Clone, Debug, PartialEq, Hash, Eq)]
78#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
79#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
80pub struct CloudOptions {
82 pub max_retries: usize,
83 #[cfg(feature = "file_cache")]
84 pub file_cache_ttl: u64,
85 pub(crate) config: Option<CloudConfig>,
86 #[cfg(feature = "cloud")]
87 pub(crate) credential_provider: Option<PlCredentialProvider>,
90}
91
92impl Default for CloudOptions {
93 fn default() -> Self {
94 Self::default_static_ref().clone()
95 }
96}
97
98impl CloudOptions {
99 pub fn default_static_ref() -> &'static Self {
100 static DEFAULT: LazyLock<CloudOptions> = LazyLock::new(|| CloudOptions {
101 max_retries: 2,
102 #[cfg(feature = "file_cache")]
103 file_cache_ttl: get_env_file_cache_ttl(),
104 config: None,
105 #[cfg(feature = "cloud")]
106 credential_provider: None,
107 });
108
109 &DEFAULT
110 }
111}
112
/// Builds a [`HeaderMap`] from `(name, value)` string pairs.
///
/// If a header name repeats, the later value replaces the earlier one (`HeaderMap::insert`).
///
/// # Errors
/// Returns a compute error for the first invalid header name or header value.
#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
    headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
    use reqwest::header::{HeaderName, HeaderValue};

    headers.iter().try_fold(
        HeaderMap::with_capacity(headers.len()),
        |mut map, (name, value)| {
            let name = HeaderName::from_str(name.as_ref()).map_err(to_compute_err)?;
            let value = HeaderValue::from_str(value.as_ref()).map_err(to_compute_err)?;
            map.insert(name, value);
            Ok(map)
        },
    )
}
130
131#[allow(dead_code)]
132fn parse_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
134 config: I,
135) -> PolarsResult<Configs<T>>
136where
137 T: FromStr + Eq + std::hash::Hash,
138{
139 Ok(config
140 .into_iter()
141 .filter_map(|(key, val)| {
143 T::from_str(key.as_ref().to_ascii_lowercase().as_str())
144 .ok()
145 .map(|typed_key| (typed_key, val.into()))
146 })
147 .collect::<Configs<T>>())
148}
149
/// Kind of storage backend a path refers to, derived from its scheme.
#[derive(Clone, PartialEq, Debug)]
pub enum CloudType {
    /// Amazon S3 (`s3`, `s3a`).
    Aws,
    /// Azure storage (`abfs`, `abfss`, `adl`, `az`, `azure`).
    Azure,
    /// Local file system.
    File,
    /// Google Cloud Storage (`gs`, `gcs`).
    Gcp,
    /// Plain HTTP(S).
    Http,
    /// Hugging Face.
    Hf,
}
161
162impl CloudType {
163 pub fn from_cloud_scheme(scheme: CloudScheme) -> Self {
164 match scheme {
165 CloudScheme::Abfs
166 | CloudScheme::Abfss
167 | CloudScheme::Adl
168 | CloudScheme::Az
169 | CloudScheme::Azure => Self::Azure,
170
171 CloudScheme::File | CloudScheme::FileNoHostname => Self::File,
172
173 CloudScheme::Gcs | CloudScheme::Gs => Self::Gcp,
174
175 CloudScheme::Hf => Self::Hf,
176
177 CloudScheme::Http | CloudScheme::Https => Self::Http,
178
179 CloudScheme::S3 | CloudScheme::S3a => Self::Aws,
180 }
181 }
182}
183
/// Retry policy for object stores: default backoff, `max_retries` attempts and a 10 second
/// overall retry timeout.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
fn get_retry_config(max_retries: usize) -> RetryConfig {
    let retry_timeout = std::time::Duration::from_secs(10);

    RetryConfig {
        max_retries,
        retry_timeout,
        backoff: BackoffConfig::default(),
    }
}
192
193pub static USER_AGENT: &str = concat!("polars", "/", env!("CARGO_PKG_VERSION"),);
194
/// Builds the [`ClientOptions`] shared by all object-store backends.
///
/// The request timeout and the connect timeout each default to 10 minutes. They can be
/// overridden (in whole seconds) with `POLARS_HTTP_CLIENT_TIMEOUT_SECONDS` and
/// `POLARS_HTTP_CONNECT_TIMEOUT_SECONDS` respectively. The user agent is set to
/// [`USER_AGENT`] and plain `http://` URLs are allowed.
///
/// # Panics
/// If either environment variable is set to anything other than a positive integer,
/// including a value that is not valid Unicode.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
    use std::num::NonZeroU64;
    use std::time::Duration;

    use reqwest::header::HeaderValue;

    const DEFAULT_TIMEOUT_SECS: u64 = 10 * 60;

    // Reads a positive timeout (seconds) from the env var `name`, or `default_secs` if unset.
    // Any present-but-invalid value panics, so misconfiguration is never silently ignored.
    fn timeout_from_env(name: &str, default_secs: u64) -> Duration {
        let secs = match std::env::var(name) {
            Ok(value) => value
                .parse::<NonZeroU64>()
                .unwrap_or_else(|_| panic!("invalid value for {name}: {value}"))
                .get(),
            Err(std::env::VarError::NotPresent) => default_secs,
            // Previously this case silently fell back to the default.
            Err(std::env::VarError::NotUnicode(value)) => {
                panic!("invalid value for {name}: {value:?}")
            },
        };
        Duration::from_secs(secs)
    }

    ClientOptions::new()
        .with_timeout(timeout_from_env(
            "POLARS_HTTP_CLIENT_TIMEOUT_SECONDS",
            DEFAULT_TIMEOUT_SECS,
        ))
        .with_connect_timeout(timeout_from_env(
            "POLARS_HTTP_CONNECT_TIMEOUT_SECONDS",
            DEFAULT_TIMEOUT_SECS,
        ))
        .with_user_agent(HeaderValue::from_static(USER_AGENT))
        .with_allow_http(true)
}
233
/// Fills in still-unset [`AmazonS3ConfigKey`]s on `builder` from local AWS config files.
///
/// `items` pairs a file path (a leading `~` is expanded via `resolve_homedir`) with a list
/// of `(regex, key)` pairs. For every key that is not yet configured on `builder`, the regex is
/// matched against the file contents. The first capture group, with surrounding whitespace
/// trimmed, becomes that key's value. A file is not read at all if all of its keys are
/// already set.
///
/// This is best-effort. A missing, unreadable or non-UTF-8 file is skipped without affecting
/// the other files. A pattern without a match is skipped without affecting the other keys.
/// Always returns `Some(())`.
#[cfg(feature = "aws")]
fn read_config(
    builder: &mut AmazonS3Builder,
    items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
    use crate::path_utils::resolve_homedir;

    for (path, keys) in items {
        if keys
            .iter()
            .all(|(_, key)| builder.get_config_value(key).is_some())
        {
            continue;
        }

        // Skip (instead of aborting the whole function) on I/O or UTF-8 errors, so later
        // entries of `items` are still consulted.
        let Ok(content) = std::fs::File::open(resolve_homedir(path)).and_then(|mut file| {
            let mut content = String::new();
            file.read_to_string(&mut content)?;
            Ok(content)
        }) else {
            continue;
        };

        for (pattern, key) in keys.iter() {
            if builder.get_config_value(key).is_some() {
                continue;
            }

            // Patterns are static strings supplied by callers; failing to compile one is a bug.
            let reg = polars_utils::regex_cache::compile_regex(pattern).unwrap();

            // A key missing from the file must not prevent the remaining keys from being read.
            let Some(value) = reg.captures(&content).and_then(|cap| cap.get(1)) else {
                continue;
            };

            *builder = std::mem::take(builder).with_config(*key, value.as_str().trim());
        }
    }

    Some(())
}
266
267impl CloudOptions {
268 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
270 self.max_retries = max_retries;
271 self
272 }
273
274 #[cfg(feature = "cloud")]
275 pub fn with_credential_provider(
276 mut self,
277 credential_provider: Option<PlCredentialProvider>,
278 ) -> Self {
279 self.credential_provider = credential_provider;
280 self
281 }
282
283 #[cfg(feature = "aws")]
285 pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
286 mut self,
287 configs: I,
288 ) -> Self {
289 self.config = Some(CloudConfig::Aws(
290 configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
291 ));
292 self
293 }
294
295 #[cfg(feature = "aws")]
297 pub async fn build_aws(
298 &self,
299 url: &str,
300 clear_cached_credentials: bool,
301 ) -> PolarsResult<impl object_store::ObjectStore> {
302 use super::credential_provider::IntoCredentialProvider;
303
304 let opt_credential_provider =
305 self.initialized_credential_provider(clear_cached_credentials)?;
306
307 let mut builder = AmazonS3Builder::from_env()
308 .with_client_options(get_client_options())
309 .with_url(url);
310
311 if let Some(credential_provider) = &opt_credential_provider {
312 let storage_update_options = parse_untyped_config::<AmazonS3ConfigKey, _>(
313 credential_provider
314 .storage_update_options()?
315 .into_iter()
316 .map(|(k, v)| (k, v.to_string())),
317 )?;
318
319 for (key, value) in storage_update_options {
320 builder = builder.with_config(key, value);
321 }
322 }
323
324 read_config(
325 &mut builder,
326 &[(
327 Path::new("~/.aws/config"),
328 &[("region\\s*=\\s*([^\r\n]*)", AmazonS3ConfigKey::Region)],
329 )],
330 );
331
332 read_config(
333 &mut builder,
334 &[(
335 Path::new("~/.aws/credentials"),
336 &[
337 (
338 "aws_access_key_id\\s*=\\s*([^\\r\\n]*)",
339 AmazonS3ConfigKey::AccessKeyId,
340 ),
341 (
342 "aws_secret_access_key\\s*=\\s*([^\\r\\n]*)",
343 AmazonS3ConfigKey::SecretAccessKey,
344 ),
345 (
346 "aws_session_token\\s*=\\s*([^\\r\\n]*)",
347 AmazonS3ConfigKey::Token,
348 ),
349 ],
350 )],
351 );
352
353 if let Some(options) = &self.config {
354 let CloudConfig::Aws(options) = options else {
355 panic!("impl error: cloud type mismatch")
356 };
357 for (key, value) in options {
358 builder = builder.with_config(*key, value);
359 }
360 }
361
362 if builder
363 .get_config_value(&AmazonS3ConfigKey::DefaultRegion)
364 .is_none()
365 && builder
366 .get_config_value(&AmazonS3ConfigKey::Region)
367 .is_none()
368 {
369 let bucket = crate::cloud::CloudLocation::new(PlPathRef::new(url), false)?.bucket;
370 let region = {
371 let mut bucket_region = BUCKET_REGION.lock().unwrap();
372 bucket_region.get(bucket.as_str()).cloned()
373 };
374
375 match region {
376 Some(region) => {
377 builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str())
378 },
379 None => {
380 if builder
381 .get_config_value(&AmazonS3ConfigKey::Endpoint)
382 .is_some()
383 {
384 builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1");
387 } else {
388 polars_warn!(
389 "'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."
390 );
391 let result = with_concurrency_budget(1, || async {
392 reqwest::Client::builder()
393 .user_agent(USER_AGENT)
394 .build()
395 .unwrap()
396 .head(format!("https://{bucket}.s3.amazonaws.com"))
397 .send()
398 .await
399 .map_err(to_compute_err)
400 })
401 .await?;
402 if let Some(region) = result.headers().get("x-amz-bucket-region") {
403 let region =
404 std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?;
405 let mut bucket_region = BUCKET_REGION.lock().unwrap();
406 bucket_region.insert(bucket, region.into());
407 builder = builder.with_config(AmazonS3ConfigKey::Region, region)
408 }
409 }
410 },
411 };
412 };
413
414 let builder = builder.with_retry(get_retry_config(self.max_retries));
415
416 let opt_credential_provider = match opt_credential_provider {
417 #[cfg(feature = "python")]
418 Some(PlCredentialProvider::Python(object)) => {
419 if pyo3::Python::attach(|py| {
420 let Ok(func_object) = object
421 .unwrap_as_provider_ref()
422 .getattr(py, "_can_use_as_provider")
423 else {
424 return PolarsResult::Ok(true);
425 };
426
427 Ok(func_object.call0(py)?.extract::<bool>(py).unwrap())
428 })? {
429 Some(PlCredentialProvider::Python(object))
430 } else {
431 None
432 }
433 },
434
435 v => v,
436 };
437
438 let builder = if let Some(credential_provider) = opt_credential_provider {
439 builder.with_credentials(credential_provider.into_aws_provider())
440 } else {
441 builder
442 };
443
444 let out = builder.build()?;
445
446 Ok(out)
447 }
448
449 #[cfg(feature = "azure")]
451 pub fn with_azure<I: IntoIterator<Item = (AzureConfigKey, impl Into<String>)>>(
452 mut self,
453 configs: I,
454 ) -> Self {
455 self.config = Some(CloudConfig::Azure(
456 configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
457 ));
458 self
459 }
460
461 #[cfg(feature = "azure")]
463 pub fn build_azure(
464 &self,
465 url: &str,
466 clear_cached_credentials: bool,
467 ) -> PolarsResult<impl object_store::ObjectStore> {
468 use super::credential_provider::IntoCredentialProvider;
469
470 let verbose = polars_core::config::verbose();
471
472 let mut builder =
475 MicrosoftAzureBuilder::from_env().with_client_options(get_client_options());
476
477 if let Some(options) = &self.config {
478 let CloudConfig::Azure(options) = options else {
479 panic!("impl error: cloud type mismatch")
480 };
481 for (key, value) in options.iter() {
482 builder = builder.with_config(*key, value);
483 }
484 }
485
486 let builder = builder
487 .with_url(url)
488 .with_retry(get_retry_config(self.max_retries));
489
490 let builder =
491 if let Some(v) = self.initialized_credential_provider(clear_cached_credentials)? {
492 if verbose {
493 eprintln!(
494 "[CloudOptions::build_azure]: Using credential provider {:?}",
495 &v
496 );
497 }
498 builder.with_credentials(v.into_azure_provider())
499 } else {
500 builder
501 };
502
503 let out = builder.build()?;
504
505 Ok(out)
506 }
507
508 #[cfg(feature = "gcp")]
510 pub fn with_gcp<I: IntoIterator<Item = (GoogleConfigKey, impl Into<String>)>>(
511 mut self,
512 configs: I,
513 ) -> Self {
514 self.config = Some(CloudConfig::Gcp(
515 configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
516 ));
517 self
518 }
519
520 #[cfg(feature = "gcp")]
522 pub fn build_gcp(
523 &self,
524 url: &str,
525 clear_cached_credentials: bool,
526 ) -> PolarsResult<impl object_store::ObjectStore> {
527 use super::credential_provider::IntoCredentialProvider;
528
529 let credential_provider = self.initialized_credential_provider(clear_cached_credentials)?;
530
531 let builder = if credential_provider.is_none() {
532 GoogleCloudStorageBuilder::from_env()
533 } else {
534 GoogleCloudStorageBuilder::new()
535 };
536
537 let mut builder = builder.with_client_options(get_client_options());
538
539 if let Some(options) = &self.config {
540 let CloudConfig::Gcp(options) = options else {
541 panic!("impl error: cloud type mismatch")
542 };
543 for (key, value) in options.iter() {
544 builder = builder.with_config(*key, value);
545 }
546 }
547
548 let builder = builder
549 .with_url(url)
550 .with_retry(get_retry_config(self.max_retries));
551
552 let builder = if let Some(v) = credential_provider {
553 builder.with_credentials(v.into_gcp_provider())
554 } else {
555 builder
556 };
557
558 let out = builder.build()?;
559
560 Ok(out)
561 }
562
563 #[cfg(feature = "http")]
564 pub fn build_http(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
565 let out = object_store::http::HttpBuilder::new()
566 .with_url(url)
567 .with_client_options({
568 let mut opts = super::get_client_options();
569 if let Some(CloudConfig::Http { headers }) = &self.config {
570 opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
571 headers.as_slice(),
572 )?);
573 }
574 opts
575 })
576 .build()?;
577
578 Ok(out)
579 }
580
581 #[allow(unused_variables)]
583 pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
584 scheme: Option<CloudScheme>,
585 config: I,
586 ) -> PolarsResult<Self> {
587 match scheme.map_or(CloudType::File, CloudType::from_cloud_scheme) {
588 CloudType::Aws => {
589 #[cfg(feature = "aws")]
590 {
591 parse_untyped_config::<AmazonS3ConfigKey, _>(config)
592 .map(|aws| Self::default().with_aws(aws))
593 }
594 #[cfg(not(feature = "aws"))]
595 {
596 polars_bail!(ComputeError: "'aws' feature is not enabled");
597 }
598 },
599 CloudType::Azure => {
600 #[cfg(feature = "azure")]
601 {
602 parse_untyped_config::<AzureConfigKey, _>(config)
603 .map(|azure| Self::default().with_azure(azure))
604 }
605 #[cfg(not(feature = "azure"))]
606 {
607 polars_bail!(ComputeError: "'azure' feature is not enabled");
608 }
609 },
610 CloudType::File => Ok(Self::default()),
611 CloudType::Http => Ok(Self::default()),
612 CloudType::Gcp => {
613 #[cfg(feature = "gcp")]
614 {
615 parse_untyped_config::<GoogleConfigKey, _>(config)
616 .map(|gcp| Self::default().with_gcp(gcp))
617 }
618 #[cfg(not(feature = "gcp"))]
619 {
620 polars_bail!(ComputeError: "'gcp' feature is not enabled");
621 }
622 },
623 CloudType::Hf => {
624 #[cfg(feature = "http")]
625 {
626 use polars_core::config;
627
628 use crate::path_utils::resolve_homedir;
629
630 let mut this = Self::default();
631 let mut token = None;
632 let verbose = config::verbose();
633
634 for (i, (k, v)) in config.into_iter().enumerate() {
635 let (k, v) = (k.as_ref(), v.into());
636
637 if i == 0 && k == "token" {
638 if verbose {
639 eprintln!("HF token sourced from storage_options");
640 }
641 token = Some(v);
642 } else {
643 polars_bail!(ComputeError: "unknown configuration key for HF: {}", k)
644 }
645 }
646
647 token = token
648 .or_else(|| {
649 let v = std::env::var("HF_TOKEN").ok();
650 if v.is_some() && verbose {
651 eprintln!("HF token sourced from HF_TOKEN env var");
652 }
653 v
654 })
655 .or_else(|| {
656 let hf_home = std::env::var("HF_HOME");
657 let hf_home = hf_home.as_deref();
658 let hf_home = hf_home.unwrap_or("~/.cache/huggingface");
659 let hf_home = resolve_homedir(&hf_home);
660 let cached_token_path = hf_home.join("token");
661
662 let v = std::string::String::from_utf8(
663 std::fs::read(&cached_token_path).ok()?,
664 )
665 .ok()
666 .filter(|x| !x.is_empty());
667
668 if v.is_some() && verbose {
669 eprintln!(
670 "HF token sourced from {}",
671 cached_token_path.to_str().unwrap()
672 );
673 }
674
675 v
676 });
677
678 if let Some(v) = token {
679 this.config = Some(CloudConfig::Http {
680 headers: vec![("Authorization".into(), format!("Bearer {v}"))],
681 })
682 }
683
684 Ok(this)
685 }
686 #[cfg(not(feature = "http"))]
687 {
688 polars_bail!(ComputeError: "'http' feature is not enabled");
689 }
690 },
691 }
692 }
693
694 #[cfg(feature = "cloud")]
697 fn initialized_credential_provider(
698 &self,
699 clear_cached_credentials: bool,
700 ) -> PolarsResult<Option<PlCredentialProvider>> {
701 if let Some(v) = self.credential_provider.clone() {
702 v.try_into_initialized(clear_cached_credentials)
703 } else {
704 Ok(None)
705 }
706 }
707}
708
#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
    use hashbrown::HashMap;

    use super::parse_untyped_config;

    /// Keys are matched case-insensitively, and keys that do not parse are dropped.
    #[cfg(feature = "aws")]
    #[test]
    fn test_parse_untyped_config() {
        use object_store::aws::AmazonS3ConfigKey;

        for secret_key in ["aws_secret_access_key", "AWS_SECRET_ACCESS_KEY"] {
            let aws_config: HashMap<_, _> = [
                (secret_key, "a_key"),
                ("aws_s3_allow_unsafe_rename", "true"),
            ]
            .into_iter()
            .collect();

            let aws_keys = parse_untyped_config::<AmazonS3ConfigKey, _>(aws_config)
                .expect("Parsing keys shouldn't have thrown an error");

            assert_eq!(aws_keys.len(), 1, "unexpected keys for {secret_key:?}");
            assert_eq!(aws_keys[0].0, AmazonS3ConfigKey::SecretAccessKey);
        }
    }
}