polars_io/cloud/
options.rs

1#[cfg(feature = "aws")]
2use std::io::Read;
3#[cfg(feature = "aws")]
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::LazyLock;
7
8#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
9use object_store::ClientOptions;
10#[cfg(feature = "aws")]
11use object_store::aws::AmazonS3Builder;
12#[cfg(feature = "aws")]
13pub use object_store::aws::AmazonS3ConfigKey;
14#[cfg(feature = "azure")]
15pub use object_store::azure::AzureConfigKey;
16#[cfg(feature = "azure")]
17use object_store::azure::MicrosoftAzureBuilder;
18#[cfg(feature = "gcp")]
19use object_store::gcp::GoogleCloudStorageBuilder;
20#[cfg(feature = "gcp")]
21pub use object_store::gcp::GoogleConfigKey;
22#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
23use object_store::{BackoffConfig, RetryConfig};
24use polars_error::*;
25#[cfg(feature = "aws")]
26use polars_utils::cache::LruCache;
27use polars_utils::plpath::{CloudScheme, PlPathRef};
28#[cfg(feature = "http")]
29use reqwest::header::HeaderMap;
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33#[cfg(feature = "cloud")]
34use super::credential_provider::PlCredentialProvider;
35#[cfg(feature = "file_cache")]
36use crate::file_cache::get_env_file_cache_ttl;
37#[cfg(feature = "aws")]
38use crate::pl_async::with_concurrency_budget;
39
/// Process-wide LRU cache (created with capacity 32) mapping an S3 bucket name to its region.
///
/// `CloudOptions::build_aws` reads it before, and inserts into it after, resolving a bucket's
/// region through the `x-amz-bucket-region` response header.
#[cfg(feature = "aws")]
static BUCKET_REGION: LazyLock<
    std::sync::Mutex<LruCache<polars_utils::pl_str::PlSmallStr, polars_utils::pl_str::PlSmallStr>>,
> = LazyLock::new(|| std::sync::Mutex::new(LruCache::with_capacity(32)));
44
/// The type of the config keys must satisfy the following requirements:
/// 1. must be easily collected into a HashMap, the type required by the `object_store` API.
/// 2. be serializable, required when the serde-lazy feature is defined.
/// 3. not actually use HashMap since that type is disallowed in Polars for performance reasons.
///
/// Currently this type is a vector of `(config key, config value)` pairs.
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;
53
// Provider-specific configuration carried by `CloudOptions::config`.
//
// Each variant only exists when the matching cargo feature is enabled. The `build_*` methods on
// `CloudOptions` panic when they find a variant that belongs to a different provider.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub(crate) enum CloudConfig {
    // Typed AWS S3 key/value pairs, applied to the `AmazonS3Builder` in `build_aws`.
    #[cfg(feature = "aws")]
    Aws(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AmazonS3ConfigKey>,
    ),
    // Typed Azure key/value pairs, applied to the `MicrosoftAzureBuilder` in `build_azure`.
    #[cfg(feature = "azure")]
    Azure(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AzureConfigKey>,
    ),
    // Typed GCP key/value pairs, applied to the `GoogleCloudStorageBuilder` in `build_gcp`.
    #[cfg(feature = "gcp")]
    Gcp(
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<GoogleConfigKey>,
    ),
    // Header name/value pairs installed as default headers of the HTTP client in `build_http`.
    #[cfg(feature = "http")]
    Http { headers: Vec<(String, String)> },
}
76
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
/// Options to connect to various cloud providers.
pub struct CloudOptions {
    // Retry limit handed to the object store retry configuration by `build_aws`, `build_azure`
    // and `build_gcp`. The shared default is 2.
    pub max_retries: usize,
    #[cfg(feature = "file_cache")]
    // Time-to-live for the file cache; the shared default comes from `get_env_file_cache_ttl()`.
    // NOTE(review): the unit is presumably seconds — confirm against the file cache module.
    pub file_cache_ttl: u64,
    // Provider-specific configuration; `None` means only builder/environment defaults apply.
    pub(crate) config: Option<CloudConfig>,
    #[cfg(feature = "cloud")]
    /// Note: In most cases you will want to access this via [`CloudOptions::initialized_credential_provider`]
    /// rather than directly.
    pub(crate) credential_provider: Option<PlCredentialProvider>,
}
91
92impl Default for CloudOptions {
93    fn default() -> Self {
94        Self::default_static_ref().clone()
95    }
96}
97
98impl CloudOptions {
99    pub fn default_static_ref() -> &'static Self {
100        static DEFAULT: LazyLock<CloudOptions> = LazyLock::new(|| CloudOptions {
101            max_retries: 2,
102            #[cfg(feature = "file_cache")]
103            file_cache_ttl: get_env_file_cache_ttl(),
104            config: None,
105            #[cfg(feature = "cloud")]
106            credential_provider: None,
107        });
108
109        &DEFAULT
110    }
111}
112
/// Converts a slice of `(name, value)` string pairs into a [`HeaderMap`].
///
/// Pairs are inserted in order; a later pair with the same header name replaces the earlier one.
///
/// # Errors
/// Returns a compute error for the first pair whose name or value is not a valid HTTP header
/// name / value.
#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
    headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
    use reqwest::header::{HeaderName, HeaderValue};

    let mut header_map = HeaderMap::with_capacity(headers.len());

    for (raw_name, raw_value) in headers.iter() {
        let name = HeaderName::from_str(raw_name.as_ref()).map_err(to_compute_err)?;
        let value = HeaderValue::from_str(raw_value.as_ref()).map_err(to_compute_err)?;
        header_map.insert(name, value);
    }

    Ok(header_map)
}
130
131#[allow(dead_code)]
132/// Parse an untype configuration hashmap to a typed configuration for the given configuration key type.
133fn parse_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
134    config: I,
135) -> PolarsResult<Configs<T>>
136where
137    T: FromStr + Eq + std::hash::Hash,
138{
139    Ok(config
140        .into_iter()
141        // Silently ignores custom upstream storage_options
142        .filter_map(|(key, val)| {
143            T::from_str(key.as_ref().to_ascii_lowercase().as_str())
144                .ok()
145                .map(|typed_key| (typed_key, val.into()))
146        })
147        .collect::<Configs<T>>())
148}
149
/// The kind of storage backend a path refers to, as derived from its [`CloudScheme`] by
/// [`CloudType::from_cloud_scheme`].
#[derive(Debug, Clone, PartialEq)]
pub enum CloudType {
    /// Amazon S3 (`S3` / `S3a` schemes).
    Aws,
    /// Microsoft Azure (`Abfs`, `Abfss`, `Adl`, `Az` and `Azure` schemes).
    Azure,
    /// Local files (`File` / `FileNoHostname` schemes); also used when no scheme is given to
    /// `CloudOptions::from_untyped_config`.
    File,
    /// Google cloud platform
    Gcp,
    /// Plain HTTP(S) endpoints.
    Http,
    /// HuggingFace
    Hf,
}
161
162impl CloudType {
163    pub fn from_cloud_scheme(scheme: &CloudScheme) -> Self {
164        match scheme {
165            CloudScheme::Abfs
166            | CloudScheme::Abfss
167            | CloudScheme::Adl
168            | CloudScheme::Az
169            | CloudScheme::Azure => Self::Azure,
170
171            CloudScheme::File | CloudScheme::FileNoHostname => Self::File,
172
173            CloudScheme::Gcs | CloudScheme::Gs => Self::Gcp,
174
175            CloudScheme::Hf => Self::Hf,
176
177            CloudScheme::Http | CloudScheme::Https => Self::Http,
178
179            CloudScheme::S3 | CloudScheme::S3a => Self::Aws,
180        }
181    }
182}
183
/// Builds the [`RetryConfig`] shared by the AWS, Azure and GCP builders: the default backoff, the
/// caller-supplied `max_retries`, and a retry timeout of 10 seconds.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
fn get_retry_config(max_retries: usize) -> RetryConfig {
    let backoff = BackoffConfig::default();
    let retry_timeout = std::time::Duration::from_secs(10);

    RetryConfig {
        backoff,
        max_retries,
        retry_timeout,
    }
}
192
/// Returns the [`ClientOptions`] shared by every object store builder in this module: request
/// timeout disabled, connect timeout disabled, and plain `http://` URLs allowed.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
    // The request timeout is disabled because it isn't reset at ACK, but keeps running from the
    // moment we start downloading a body.
    // https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.timeout
    let options = ClientOptions::new().with_timeout_disabled();

    // Concurrency can increase connection latency, so no connect timeout is enforced either.
    let options = options.with_connect_timeout_disabled();

    options.with_allow_http(true)
}
204
/// Fills unset AWS builder options from local configuration files (best effort).
///
/// `items` pairs a file path (passed through `resolve_homedir` before opening) with a list of
/// `(regex, key)` entries. For every `key` that `builder` has no value for yet, the first capture
/// group of `regex` matched against the file content is stored under `key`. Values already
/// present on `builder` are never overwritten, and a file is not opened at all when every one of
/// its keys is already set.
///
/// A file that cannot be read (or is not valid UTF-8) and a pattern that does not match are
/// skipped individually, so one missing entry no longer prevents the remaining keys and files
/// from being applied.
///
/// Returns `Some(())` when every consulted file was readable and every missing key was found,
/// `None` otherwise.
///
/// # Panics
/// Panics if one of the patterns is not a valid regex.
#[cfg(feature = "aws")]
fn read_config(
    builder: &mut AmazonS3Builder,
    items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
    use crate::path_utils::resolve_homedir;

    // Tracks whether anything had to be skipped, to keep the `None` return signal.
    let mut complete = true;

    for (path, keys) in items {
        if keys
            .iter()
            .all(|(_, key)| builder.get_config_value(key).is_some())
        {
            continue;
        }

        let Ok(mut config) = std::fs::File::open(resolve_homedir(path)) else {
            complete = false;
            continue;
        };
        let mut buf = vec![];
        if config.read_to_end(&mut buf).is_err() {
            complete = false;
            continue;
        }
        let Ok(content) = std::str::from_utf8(buf.as_ref()) else {
            complete = false;
            continue;
        };

        for (pattern, key) in keys.iter() {
            if builder.get_config_value(key).is_some() {
                continue;
            }

            let reg = polars_utils::regex_cache::compile_regex(pattern).unwrap();
            let Some(cap) = reg.captures(content) else {
                complete = false;
                continue;
            };
            let Some(m) = cap.get(1) else {
                complete = false;
                continue;
            };

            // `with_config` consumes the builder, so temporarily move it out of the `&mut`.
            *builder = std::mem::take(builder).with_config(*key, m.as_str());
        }
    }

    complete.then_some(())
}
237
238impl CloudOptions {
239    /// Set the maximum number of retries.
240    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
241        self.max_retries = max_retries;
242        self
243    }
244
245    #[cfg(feature = "cloud")]
246    pub fn with_credential_provider(
247        mut self,
248        credential_provider: Option<PlCredentialProvider>,
249    ) -> Self {
250        self.credential_provider = credential_provider;
251        self
252    }
253
254    /// Set the configuration for AWS connections. This is the preferred API from rust.
255    #[cfg(feature = "aws")]
256    pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
257        mut self,
258        configs: I,
259    ) -> Self {
260        self.config = Some(CloudConfig::Aws(
261            configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
262        ));
263        self
264    }
265
266    /// Build the [`object_store::ObjectStore`] implementation for AWS.
267    #[cfg(feature = "aws")]
268    pub async fn build_aws(
269        &self,
270        url: &str,
271        clear_cached_credentials: bool,
272    ) -> PolarsResult<impl object_store::ObjectStore> {
273        use super::credential_provider::IntoCredentialProvider;
274
275        let opt_credential_provider =
276            self.initialized_credential_provider(clear_cached_credentials)?;
277
278        let mut builder = AmazonS3Builder::from_env()
279            .with_client_options(get_client_options())
280            .with_url(url);
281
282        if let Some(credential_provider) = &opt_credential_provider {
283            let storage_update_options = parse_untyped_config::<AmazonS3ConfigKey, _>(
284                credential_provider
285                    .storage_update_options()?
286                    .into_iter()
287                    .map(|(k, v)| (k, v.to_string())),
288            )?;
289
290            for (key, value) in storage_update_options {
291                builder = builder.with_config(key, value);
292            }
293        }
294
295        read_config(
296            &mut builder,
297            &[(
298                Path::new("~/.aws/config"),
299                &[("region\\s*=\\s*([^\r\n]*)", AmazonS3ConfigKey::Region)],
300            )],
301        );
302
303        read_config(
304            &mut builder,
305            &[(
306                Path::new("~/.aws/credentials"),
307                &[
308                    (
309                        "aws_access_key_id\\s*=\\s*([^\\r\\n]*)",
310                        AmazonS3ConfigKey::AccessKeyId,
311                    ),
312                    (
313                        "aws_secret_access_key\\s*=\\s*([^\\r\\n]*)",
314                        AmazonS3ConfigKey::SecretAccessKey,
315                    ),
316                    (
317                        "aws_session_token\\s*=\\s*([^\\r\\n]*)",
318                        AmazonS3ConfigKey::Token,
319                    ),
320                ],
321            )],
322        );
323
324        if let Some(options) = &self.config {
325            let CloudConfig::Aws(options) = options else {
326                panic!("impl error: cloud type mismatch")
327            };
328            for (key, value) in options {
329                builder = builder.with_config(*key, value);
330            }
331        }
332
333        if builder
334            .get_config_value(&AmazonS3ConfigKey::DefaultRegion)
335            .is_none()
336            && builder
337                .get_config_value(&AmazonS3ConfigKey::Region)
338                .is_none()
339        {
340            let bucket = crate::cloud::CloudLocation::new(PlPathRef::new(url), false)?.bucket;
341            let region = {
342                let mut bucket_region = BUCKET_REGION.lock().unwrap();
343                bucket_region.get(bucket.as_str()).cloned()
344            };
345
346            match region {
347                Some(region) => {
348                    builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str())
349                },
350                None => {
351                    if builder
352                        .get_config_value(&AmazonS3ConfigKey::Endpoint)
353                        .is_some()
354                    {
355                        // Set a default value if the endpoint is not aws.
356                        // See: #13042
357                        builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1");
358                    } else {
359                        polars_warn!(
360                            "'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."
361                        );
362                        let result = with_concurrency_budget(1, || async {
363                            reqwest::Client::builder()
364                                .build()
365                                .unwrap()
366                                .head(format!("https://{bucket}.s3.amazonaws.com"))
367                                .send()
368                                .await
369                                .map_err(to_compute_err)
370                        })
371                        .await?;
372                        if let Some(region) = result.headers().get("x-amz-bucket-region") {
373                            let region =
374                                std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?;
375                            let mut bucket_region = BUCKET_REGION.lock().unwrap();
376                            bucket_region.insert(bucket, region.into());
377                            builder = builder.with_config(AmazonS3ConfigKey::Region, region)
378                        }
379                    }
380                },
381            };
382        };
383
384        let builder = builder.with_retry(get_retry_config(self.max_retries));
385
386        let opt_credential_provider = match opt_credential_provider {
387            #[cfg(feature = "python")]
388            Some(PlCredentialProvider::Python(object)) => {
389                if pyo3::Python::with_gil(|py| {
390                    let Ok(func_object) = object
391                        .unwrap_as_provider_ref()
392                        .getattr(py, "_can_use_as_provider")
393                    else {
394                        return PolarsResult::Ok(true);
395                    };
396
397                    Ok(func_object.call0(py)?.extract::<bool>(py).unwrap())
398                })? {
399                    Some(PlCredentialProvider::Python(object))
400                } else {
401                    None
402                }
403            },
404
405            v => v,
406        };
407
408        let builder = if let Some(credential_provider) = opt_credential_provider {
409            builder.with_credentials(credential_provider.into_aws_provider())
410        } else {
411            builder
412        };
413
414        let out = builder.build()?;
415
416        Ok(out)
417    }
418
419    /// Set the configuration for Azure connections. This is the preferred API from rust.
420    #[cfg(feature = "azure")]
421    pub fn with_azure<I: IntoIterator<Item = (AzureConfigKey, impl Into<String>)>>(
422        mut self,
423        configs: I,
424    ) -> Self {
425        self.config = Some(CloudConfig::Azure(
426            configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
427        ));
428        self
429    }
430
431    /// Build the [`object_store::ObjectStore`] implementation for Azure.
432    #[cfg(feature = "azure")]
433    pub fn build_azure(
434        &self,
435        url: &str,
436        clear_cached_credentials: bool,
437    ) -> PolarsResult<impl object_store::ObjectStore> {
438        use super::credential_provider::IntoCredentialProvider;
439
440        let verbose = polars_core::config::verbose();
441
442        // The credential provider `self.credentials` is prioritized if it is set. We also need
443        // `from_env()` as it may source environment configured storage account name.
444        let mut builder =
445            MicrosoftAzureBuilder::from_env().with_client_options(get_client_options());
446
447        if let Some(options) = &self.config {
448            let CloudConfig::Azure(options) = options else {
449                panic!("impl error: cloud type mismatch")
450            };
451            for (key, value) in options.iter() {
452                builder = builder.with_config(*key, value);
453            }
454        }
455
456        let builder = builder
457            .with_url(url)
458            .with_retry(get_retry_config(self.max_retries));
459
460        let builder =
461            if let Some(v) = self.initialized_credential_provider(clear_cached_credentials)? {
462                if verbose {
463                    eprintln!(
464                        "[CloudOptions::build_azure]: Using credential provider {:?}",
465                        &v
466                    );
467                }
468                builder.with_credentials(v.into_azure_provider())
469            } else {
470                builder
471            };
472
473        let out = builder.build()?;
474
475        Ok(out)
476    }
477
478    /// Set the configuration for GCP connections. This is the preferred API from rust.
479    #[cfg(feature = "gcp")]
480    pub fn with_gcp<I: IntoIterator<Item = (GoogleConfigKey, impl Into<String>)>>(
481        mut self,
482        configs: I,
483    ) -> Self {
484        self.config = Some(CloudConfig::Gcp(
485            configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
486        ));
487        self
488    }
489
490    /// Build the [`object_store::ObjectStore`] implementation for GCP.
491    #[cfg(feature = "gcp")]
492    pub fn build_gcp(
493        &self,
494        url: &str,
495        clear_cached_credentials: bool,
496    ) -> PolarsResult<impl object_store::ObjectStore> {
497        use super::credential_provider::IntoCredentialProvider;
498
499        let credential_provider = self.initialized_credential_provider(clear_cached_credentials)?;
500
501        let builder = if credential_provider.is_none() {
502            GoogleCloudStorageBuilder::from_env()
503        } else {
504            GoogleCloudStorageBuilder::new()
505        };
506
507        let mut builder = builder.with_client_options(get_client_options());
508
509        if let Some(options) = &self.config {
510            let CloudConfig::Gcp(options) = options else {
511                panic!("impl error: cloud type mismatch")
512            };
513            for (key, value) in options.iter() {
514                builder = builder.with_config(*key, value);
515            }
516        }
517
518        let builder = builder
519            .with_url(url)
520            .with_retry(get_retry_config(self.max_retries));
521
522        let builder = if let Some(v) = credential_provider {
523            builder.with_credentials(v.into_gcp_provider())
524        } else {
525            builder
526        };
527
528        let out = builder.build()?;
529
530        Ok(out)
531    }
532
533    #[cfg(feature = "http")]
534    pub fn build_http(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
535        let out = object_store::http::HttpBuilder::new()
536            .with_url(url)
537            .with_client_options({
538                let mut opts = super::get_client_options();
539                if let Some(CloudConfig::Http { headers }) = &self.config {
540                    opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
541                        headers.as_slice(),
542                    )?);
543                }
544                opts
545            })
546            .build()?;
547
548        Ok(out)
549    }
550
551    /// Parse a configuration from a Hashmap. This is the interface from Python.
552    #[allow(unused_variables)]
553    pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
554        scheme: Option<&CloudScheme>,
555        config: I,
556    ) -> PolarsResult<Self> {
557        match scheme.map_or(CloudType::File, CloudType::from_cloud_scheme) {
558            CloudType::Aws => {
559                #[cfg(feature = "aws")]
560                {
561                    parse_untyped_config::<AmazonS3ConfigKey, _>(config)
562                        .map(|aws| Self::default().with_aws(aws))
563                }
564                #[cfg(not(feature = "aws"))]
565                {
566                    polars_bail!(ComputeError: "'aws' feature is not enabled");
567                }
568            },
569            CloudType::Azure => {
570                #[cfg(feature = "azure")]
571                {
572                    parse_untyped_config::<AzureConfigKey, _>(config)
573                        .map(|azure| Self::default().with_azure(azure))
574                }
575                #[cfg(not(feature = "azure"))]
576                {
577                    polars_bail!(ComputeError: "'azure' feature is not enabled");
578                }
579            },
580            CloudType::File => Ok(Self::default()),
581            CloudType::Http => Ok(Self::default()),
582            CloudType::Gcp => {
583                #[cfg(feature = "gcp")]
584                {
585                    parse_untyped_config::<GoogleConfigKey, _>(config)
586                        .map(|gcp| Self::default().with_gcp(gcp))
587                }
588                #[cfg(not(feature = "gcp"))]
589                {
590                    polars_bail!(ComputeError: "'gcp' feature is not enabled");
591                }
592            },
593            CloudType::Hf => {
594                #[cfg(feature = "http")]
595                {
596                    use polars_core::config;
597
598                    use crate::path_utils::resolve_homedir;
599
600                    let mut this = Self::default();
601                    let mut token = None;
602                    let verbose = config::verbose();
603
604                    for (i, (k, v)) in config.into_iter().enumerate() {
605                        let (k, v) = (k.as_ref(), v.into());
606
607                        if i == 0 && k == "token" {
608                            if verbose {
609                                eprintln!("HF token sourced from storage_options");
610                            }
611                            token = Some(v);
612                        } else {
613                            polars_bail!(ComputeError: "unknown configuration key for HF: {}", k)
614                        }
615                    }
616
617                    token = token
618                        .or_else(|| {
619                            let v = std::env::var("HF_TOKEN").ok();
620                            if v.is_some() && verbose {
621                                eprintln!("HF token sourced from HF_TOKEN env var");
622                            }
623                            v
624                        })
625                        .or_else(|| {
626                            let hf_home = std::env::var("HF_HOME");
627                            let hf_home = hf_home.as_deref();
628                            let hf_home = hf_home.unwrap_or("~/.cache/huggingface");
629                            let hf_home = resolve_homedir(&hf_home);
630                            let cached_token_path = hf_home.join("token");
631
632                            let v = std::string::String::from_utf8(
633                                std::fs::read(&cached_token_path).ok()?,
634                            )
635                            .ok()
636                            .filter(|x| !x.is_empty());
637
638                            if v.is_some() && verbose {
639                                eprintln!(
640                                    "HF token sourced from {}",
641                                    cached_token_path.to_str().unwrap()
642                                );
643                            }
644
645                            v
646                        });
647
648                    if let Some(v) = token {
649                        this.config = Some(CloudConfig::Http {
650                            headers: vec![("Authorization".into(), format!("Bearer {v}"))],
651                        })
652                    }
653
654                    Ok(this)
655                }
656                #[cfg(not(feature = "http"))]
657                {
658                    polars_bail!(ComputeError: "'http' feature is not enabled");
659                }
660            },
661        }
662    }
663
664    /// Python passes a credential provider builder that needs to be called to get the actual credential
665    /// provider.
666    #[cfg(feature = "cloud")]
667    fn initialized_credential_provider(
668        &self,
669        clear_cached_credentials: bool,
670    ) -> PolarsResult<Option<PlCredentialProvider>> {
671        if let Some(v) = self.credential_provider.clone() {
672            v.try_into_initialized(clear_cached_credentials)
673        } else {
674            Ok(None)
675        }
676    }
677}
678
#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
    use hashbrown::HashMap;

    use super::parse_untyped_config;

    /// Known keys are recognized regardless of ASCII case, while keys that are not AWS config
    /// keys are dropped without raising an error.
    #[cfg(feature = "aws")]
    #[test]
    fn test_parse_untyped_config() {
        use object_store::aws::AmazonS3ConfigKey;

        for secret_key_name in ["aws_secret_access_key", "AWS_SECRET_ACCESS_KEY"] {
            let aws_config: HashMap<_, _> = [
                (secret_key_name, "a_key"),
                ("aws_s3_allow_unsafe_rename", "true"),
            ]
            .into_iter()
            .collect();

            let aws_keys = parse_untyped_config::<AmazonS3ConfigKey, _>(aws_config)
                .expect("Parsing keys shouldn't have thrown an error");

            assert_eq!(
                aws_keys.first().unwrap().0,
                AmazonS3ConfigKey::SecretAccessKey
            );
            assert_eq!(aws_keys.len(), 1);
        }
    }
}