polars_io/cloud/
options.rs

1#[cfg(feature = "aws")]
2use std::io::Read;
3#[cfg(feature = "aws")]
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::LazyLock;
7
8#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
9use object_store::ClientOptions;
10#[cfg(feature = "aws")]
11use object_store::aws::AmazonS3Builder;
12#[cfg(feature = "aws")]
13pub use object_store::aws::AmazonS3ConfigKey;
14#[cfg(feature = "azure")]
15pub use object_store::azure::AzureConfigKey;
16#[cfg(feature = "azure")]
17use object_store::azure::MicrosoftAzureBuilder;
18#[cfg(feature = "gcp")]
19use object_store::gcp::GoogleCloudStorageBuilder;
20#[cfg(feature = "gcp")]
21pub use object_store::gcp::GoogleConfigKey;
22use polars_error::*;
23#[cfg(feature = "aws")]
24use polars_utils::cache::LruCache;
25use polars_utils::pl_path::{CloudScheme, PlRefPath};
26use polars_utils::total_ord::TotalOrdWrap;
27#[cfg(feature = "http")]
28use reqwest::header::HeaderMap;
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31
32#[cfg(feature = "cloud")]
33use super::credential_provider::PlCredentialProvider;
34#[cfg(feature = "cloud")]
35use crate::cloud::ObjectStoreErrorContext;
36#[cfg(feature = "file_cache")]
37use crate::file_cache::get_env_file_cache_ttl;
38#[cfg(feature = "aws")]
39use crate::pl_async::with_concurrency_budget;
40
/// Process-wide LRU cache (created with capacity 32) mapping a bucket name to its region.
///
/// It is read and filled by `CloudOptions::build_aws` when no region has been configured,
/// so that the region lookup request is not repeated for a bucket that is still cached.
#[cfg(feature = "aws")]
static BUCKET_REGION: LazyLock<
    std::sync::Mutex<LruCache<polars_utils::pl_str::PlSmallStr, polars_utils::pl_str::PlSmallStr>>,
> = LazyLock::new(|| std::sync::Mutex::new(LruCache::with_capacity(32)));
45
/// The type of the config keys must satisfy the following requirements:
/// 1. must be easily collected into a HashMap, the type required by the `object_store` API.
/// 2. be serializable, required when the serde-lazy feature is defined.
/// 3. not actually use HashMap since that type is disallowed in Polars for performance reasons.
///
/// Currently this type is a vector of (config key, config value) pairs.
// NOTE(review): `dead_code` is presumably allowed because every user of this alias is gated
// behind a cloud provider feature — confirm.
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;
54
/// Provider-specific configuration carried by [`CloudOptions`].
///
/// Each variant only exists when the matching cargo feature is enabled. The `build_*` methods
/// on [`CloudOptions`] panic if they are handed a variant that belongs to another provider.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub(crate) enum CloudConfig {
    /// Typed key/value pairs applied to the `AmazonS3Builder`.
    #[cfg(feature = "aws")]
    Aws(
        // The JSON schema describes the pairs as plain strings.
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AmazonS3ConfigKey>,
    ),
    /// Typed key/value pairs applied to the `MicrosoftAzureBuilder`.
    #[cfg(feature = "azure")]
    Azure(
        // The JSON schema describes the pairs as plain strings.
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<AzureConfigKey>,
    ),
    /// Typed key/value pairs applied to the `GoogleCloudStorageBuilder`.
    #[cfg(feature = "gcp")]
    Gcp(
        // The JSON schema describes the pairs as plain strings.
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<(String, String)>"))]
        Configs<GoogleConfigKey>,
    ),
    /// (name, value) pairs installed as default headers on the HTTP client.
    #[cfg(feature = "http")]
    Http { headers: Vec<(String, String)> },
}
77
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
/// Options to connect to various cloud providers.
pub struct CloudOptions {
    /// Time-to-live used by the file cache; the default comes from `get_env_file_cache_ttl()`.
    #[cfg(feature = "file_cache")]
    pub file_cache_ttl: u64,
    /// Provider-specific configuration; `None` means only builder/environment defaults apply.
    pub(crate) config: Option<CloudConfig>,
    /// Retry behaviour for cloud requests. Missing in serialized data means "all defaults".
    #[cfg_attr(feature = "serde", serde(default))]
    pub retry_config: CloudRetryConfig,
    #[cfg(feature = "cloud")]
    /// Note: In most cases you will want to access this via [`CloudOptions::initialized_credential_provider`]
    /// rather than directly.
    pub(crate) credential_provider: Option<PlCredentialProvider>,
}
93
94impl Default for CloudOptions {
95    fn default() -> Self {
96        Self::default_static_ref().clone()
97    }
98}
99
100impl CloudOptions {
101    pub fn default_static_ref() -> &'static Self {
102        static DEFAULT: LazyLock<CloudOptions> = LazyLock::new(|| CloudOptions {
103            #[cfg(feature = "file_cache")]
104            file_cache_ttl: get_env_file_cache_ttl(),
105            config: None,
106            retry_config: CloudRetryConfig::default(),
107            #[cfg(feature = "cloud")]
108            credential_provider: None,
109        });
110
111        &DEFAULT
112    }
113}
114
/// User-facing retry settings for cloud requests.
///
/// Every field is optional. When converted into `object_store::RetryConfig`, a `None` field
/// falls back to a default that can be overridden through an environment variable (named on
/// each field below).
#[derive(Clone, Copy, Default, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct CloudRetryConfig {
    /// Maximum number of retries. Fallback: `POLARS_CLOUD_MAX_RETRIES`, else 2.
    pub max_retries: Option<usize>,
    /// Retry timeout. Fallback: `POLARS_CLOUD_RETRY_TIMEOUT_MS`, else 10 seconds.
    pub retry_timeout: Option<std::time::Duration>,
    /// Initial backoff. Fallback: `POLARS_CLOUD_RETRY_INIT_BACKOFF_MS`, else 100 milliseconds.
    pub retry_init_backoff: Option<std::time::Duration>,
    /// Maximum backoff. Fallback: `POLARS_CLOUD_RETRY_MAX_BACKOFF_MS`, else 15 seconds.
    pub retry_max_backoff: Option<std::time::Duration>,
    /// Backoff base multiplier. Fallback: `POLARS_CLOUD_RETRY_BASE_MULTIPLIER`, else 2.0.
    pub retry_base_multiplier: Option<TotalOrdWrap<f64>>,
}
125
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
impl From<CloudRetryConfig> for object_store::RetryConfig {
    /// Converts the optional user settings into a fully populated `object_store` retry config.
    ///
    /// Fields left as `None` are taken from `DEFAULTS`, which is built lazily from the
    /// `POLARS_CLOUD_*` environment variables the first time a fallback is needed.
    ///
    /// # Panics
    ///
    /// Panics if a fallback is needed and one of the environment variables is set to a value
    /// that cannot be parsed.
    fn from(value: CloudRetryConfig) -> Self {
        use std::time::Duration;

        use polars_core::config::verbose;

        /// Returns `default` when `name` cannot be read from the environment, otherwise the
        /// parsed value. Panics when the variable is readable but does not parse as `T`.
        fn parse_env_var<T: FromStr>(default: T, name: &'static str) -> T {
            match std::env::var(name) {
                Err(_) => default,
                Ok(raw) => match raw.parse::<T>() {
                    Ok(parsed) => parsed,
                    Err(_) => panic!("invalid value for {name}: {raw}"),
                },
            }
        }

        static DEFAULTS: LazyLock<object_store::RetryConfig> = LazyLock::new(|| {
            let init_backoff =
                Duration::from_millis(parse_env_var(100, "POLARS_CLOUD_RETRY_INIT_BACKOFF_MS"));
            let max_backoff = Duration::from_millis(parse_env_var(
                15 * 1000,
                "POLARS_CLOUD_RETRY_MAX_BACKOFF_MS",
            ));
            let base = parse_env_var(2., "POLARS_CLOUD_RETRY_BASE_MULTIPLIER");
            let max_retries = parse_env_var(2, "POLARS_CLOUD_MAX_RETRIES");
            let retry_timeout =
                Duration::from_millis(parse_env_var(10 * 1000, "POLARS_CLOUD_RETRY_TIMEOUT_MS"));

            object_store::RetryConfig {
                backoff: object_store::BackoffConfig {
                    init_backoff,
                    max_backoff,
                    base,
                },
                max_retries,
                retry_timeout,
            }
        });

        // `DEFAULTS` is only touched inside the `None` arms so that the environment is not
        // parsed at all when every field was given explicitly.
        let init_backoff = match value.retry_init_backoff {
            Some(explicit) => explicit,
            None => DEFAULTS.backoff.init_backoff,
        };
        let max_backoff = match value.retry_max_backoff {
            Some(explicit) => explicit,
            None => DEFAULTS.backoff.max_backoff,
        };
        let base = match value.retry_base_multiplier {
            Some(explicit) => explicit.0,
            None => DEFAULTS.backoff.base,
        };
        let max_retries = match value.max_retries {
            Some(explicit) => explicit,
            None => DEFAULTS.max_retries,
        };
        let retry_timeout = match value.retry_timeout {
            Some(explicit) => explicit,
            None => DEFAULTS.retry_timeout,
        };

        let out = object_store::RetryConfig {
            backoff: object_store::BackoffConfig {
                init_backoff,
                max_backoff,
                base,
            },
            max_retries,
            retry_timeout,
        };

        if verbose() {
            eprintln!("object-store retry config: {:?}", &out)
        }

        out
    }
}
186
/// Builds a `HeaderMap` from `(name, value)` string pairs.
///
/// Pairs are inserted in slice order; a later pair with the same header name replaces the
/// earlier one.
///
/// # Errors
///
/// Returns a compute error when a name or a value is not a valid HTTP header name/value.
#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
    headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
    use reqwest::header::{HeaderName, HeaderValue};

    let mut header_map = HeaderMap::with_capacity(headers.len());

    for (raw_name, raw_value) in headers.iter() {
        let name = HeaderName::from_str(raw_name.as_ref()).map_err(to_compute_err)?;
        let value = HeaderValue::from_str(raw_value.as_ref()).map_err(to_compute_err)?;
        header_map.insert(name, value);
    }

    Ok(header_map)
}
204
205#[allow(dead_code)]
206/// Parse an untype configuration hashmap to a typed configuration for the given configuration key type.
207fn parse_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
208    config: I,
209) -> PolarsResult<Configs<T>>
210where
211    T: FromStr + Eq + std::hash::Hash,
212{
213    Ok(config
214        .into_iter()
215        // Silently ignores custom upstream storage_options
216        .filter_map(|(key, val)| {
217            T::from_str(key.as_ref().to_ascii_lowercase().as_str())
218                .ok()
219                .map(|typed_key| (typed_key, val.into()))
220        })
221        .collect::<Configs<T>>())
222}
223
/// Provider family a URI scheme resolves to; see [`CloudType::from_cloud_scheme`] for the
/// scheme mapping.
#[derive(Debug, Clone, PartialEq)]
pub enum CloudType {
    /// Amazon S3 (`s3`, `s3a` schemes)
    Aws,
    /// Microsoft Azure storage (`abfs`, `abfss`, `adl`, `az`, `azure` schemes)
    Azure,
    /// URI with 'file:' scheme
    File,
    /// Google cloud platform
    Gcp,
    /// Plain HTTP(S)
    Http,
    /// HuggingFace
    Hf,
}
236
237impl CloudType {
238    pub fn from_cloud_scheme(scheme: CloudScheme) -> Self {
239        match scheme {
240            CloudScheme::Abfs
241            | CloudScheme::Abfss
242            | CloudScheme::Adl
243            | CloudScheme::Az
244            | CloudScheme::Azure => Self::Azure,
245
246            CloudScheme::File | CloudScheme::FileNoHostname => Self::File,
247
248            CloudScheme::Gcs | CloudScheme::Gs => Self::Gcp,
249
250            CloudScheme::Hf => Self::Hf,
251
252            CloudScheme::Http | CloudScheme::Https => Self::Http,
253
254            CloudScheme::S3 | CloudScheme::S3a => Self::Aws,
255        }
256    }
257}
258
/// User agent sent with cloud requests, of the form `polars/<crate version>`.
pub static USER_AGENT: &str = concat!("polars", "/", env!("CARGO_PKG_VERSION"),);
260
/// Builds the `ClientOptions` shared by all object store clients.
///
/// The response-body timeout is disabled, plain HTTP is allowed, and the polars user agent is
/// set. The connect timeout is read from `POLARS_HTTP_CONNECT_TIMEOUT_SECONDS` and defaults to
/// 5 minutes when the variable cannot be read.
///
/// # Panics
///
/// Panics if `POLARS_HTTP_CONNECT_TIMEOUT_SECONDS` is set but is not a non-zero integer.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
    use std::num::NonZeroU64;
    use std::time::Duration;

    use reqwest::header::HeaderValue;

    let base = ClientOptions::new()
        // Disables the time limit for downloading the response body.
        .with_timeout_disabled();

    let connect_timeout_secs = match std::env::var("POLARS_HTTP_CONNECT_TIMEOUT_SECONDS") {
        Ok(raw) => match raw.parse::<NonZeroU64>() {
            Ok(seconds) => seconds.get(),
            Err(_) => panic!("invalid value for POLARS_HTTP_CONNECT_TIMEOUT_SECONDS: {raw}"),
        },
        Err(_) => 5 * 60,
    };

    // Set the time limit for establishing the connection.
    base.with_connect_timeout(Duration::from_secs(connect_timeout_secs))
        .with_user_agent(HeaderValue::from_static(USER_AGENT))
        .with_allow_http(true)
}
286
/// Best-effort fill of missing builder keys from local config files.
///
/// For each `(path, keys)` item, the file is only read if at least one of its keys is not yet
/// set on `builder`. Each missing key is then taken from capture group 1 of the first match of
/// its regex pattern in the file content.
///
/// Returns `None` as soon as a file cannot be opened or read as UTF-8, or a pattern for a
/// missing key does not match; keys applied before that point stay applied. Returns `Some(())`
/// when every item was processed.
///
/// # Panics
///
/// Panics if a pattern is not a valid regex.
#[cfg(feature = "aws")]
fn read_config(
    builder: &mut AmazonS3Builder,
    items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
    use crate::path_utils::resolve_homedir;

    for (path, keys) in items {
        let any_key_missing = keys
            .iter()
            .any(|(_, key)| builder.get_config_value(key).is_none());
        if !any_key_missing {
            continue;
        }

        let mut file = std::fs::File::open(resolve_homedir(path)).ok()?;
        let mut content = String::new();
        file.read_to_string(&mut content).ok()?;

        for (pattern, key) in keys.iter() {
            if builder.get_config_value(key).is_some() {
                continue;
            }

            let regex = polars_utils::regex_cache::compile_regex(pattern).unwrap();
            let value = regex.captures(content.as_str())?.get(1)?.as_str();
            // `with_config` consumes the builder, so temporarily move it out of the reference.
            *builder = std::mem::take(builder).with_config(*key, value);
        }
    }

    Some(())
}
319
320impl CloudOptions {
321    pub fn with_retry_config(mut self, retry_config: CloudRetryConfig) -> Self {
322        self.retry_config = retry_config;
323        self
324    }
325
326    #[cfg(feature = "cloud")]
327    pub fn with_credential_provider(
328        mut self,
329        credential_provider: Option<PlCredentialProvider>,
330    ) -> Self {
331        self.credential_provider = credential_provider;
332        self
333    }
334
335    /// Set the configuration for AWS connections. This is the preferred API from rust.
336    #[cfg(feature = "aws")]
337    pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
338        mut self,
339        configs: I,
340    ) -> Self {
341        self.config = Some(CloudConfig::Aws(
342            configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
343        ));
344        self
345    }
346
    /// Build the [`object_store::ObjectStore`] implementation for AWS.
    ///
    /// Builder settings are applied in this order, later steps overriding earlier ones:
    /// 1. `AmazonS3Builder::from_env()`, the shared client options and `url`.
    /// 2. The storage update options reported by the credential provider, if any.
    /// 3. Values from `~/.aws/config` and `~/.aws/credentials`, only for keys still unset.
    /// 4. The explicit AWS options stored in `self.config`.
    ///
    /// If no region is configured afterwards, it is taken from the bucket-region cache, set to
    /// `us-east-1` when a custom endpoint is configured, or otherwise looked up with a `HEAD`
    /// request against the bucket.
    ///
    /// `clear_cached_credentials` is forwarded to the credential provider initialization.
    ///
    /// # Errors
    ///
    /// Returns an error if the credential provider cannot be initialized, the bucket cannot be
    /// derived from `url`, the region lookup request fails, or the store cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if `self.config` holds a configuration for another provider.
    #[cfg(feature = "aws")]
    pub async fn build_aws(
        &self,
        url: PlRefPath,
        clear_cached_credentials: bool,
    ) -> PolarsResult<impl object_store::ObjectStore> {
        use super::credential_provider::IntoCredentialProvider;

        let opt_credential_provider =
            self.initialized_credential_provider(clear_cached_credentials)?;

        let mut builder = AmazonS3Builder::from_env()
            .with_client_options(get_client_options())
            .with_url(url.clone().to_string());

        // Options reported by the credential provider; keys that are not valid
        // `AmazonS3ConfigKey`s are dropped by `parse_untyped_config`.
        if let Some(credential_provider) = &opt_credential_provider {
            let storage_update_options = parse_untyped_config::<AmazonS3ConfigKey, _>(
                credential_provider
                    .storage_update_options()?
                    .into_iter()
                    .map(|(k, v)| (k, v.to_string())),
            )?;

            for (key, value) in storage_update_options {
                builder = builder.with_config(key, value);
            }
        }

        // Best effort: the result is ignored, a missing file or key simply leaves the builder
        // as it was.
        read_config(
            &mut builder,
            &[(
                Path::new("~/.aws/config"),
                &[("region\\s*=\\s*([^\r\n]*)", AmazonS3ConfigKey::Region)],
            )],
        );

        read_config(
            &mut builder,
            &[(
                Path::new("~/.aws/credentials"),
                &[
                    (
                        "aws_access_key_id\\s*=\\s*([^\\r\\n]*)",
                        AmazonS3ConfigKey::AccessKeyId,
                    ),
                    (
                        "aws_secret_access_key\\s*=\\s*([^\\r\\n]*)",
                        AmazonS3ConfigKey::SecretAccessKey,
                    ),
                    (
                        "aws_session_token\\s*=\\s*([^\\r\\n]*)",
                        AmazonS3ConfigKey::Token,
                    ),
                ],
            )],
        );

        // Explicit user options are applied last so they win over everything above.
        if let Some(options) = &self.config {
            let CloudConfig::Aws(options) = options else {
                panic!("impl error: cloud type mismatch")
            };
            for (key, value) in options {
                builder = builder.with_config(*key, value);
            }
        }

        // Region discovery, only when neither a default region nor a region is configured.
        if builder
            .get_config_value(&AmazonS3ConfigKey::DefaultRegion)
            .is_none()
            && builder
                .get_config_value(&AmazonS3ConfigKey::Region)
                .is_none()
        {
            let bucket = crate::cloud::CloudLocation::new(url.clone(), false)?.bucket;
            // The lock guard is confined to this block so it is released before any `.await`.
            let region = {
                let mut bucket_region = BUCKET_REGION.lock().unwrap();
                bucket_region.get(bucket.as_str()).cloned()
            };

            match region {
                Some(region) => {
                    builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str())
                },
                None => {
                    if builder
                        .get_config_value(&AmazonS3ConfigKey::Endpoint)
                        .is_some()
                    {
                        // Set a default value if the endpoint is not aws.
                        // See: #13042
                        builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1");
                    } else {
                        polars_warn!(
                            "'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."
                        );
                        // Ask the bucket endpoint for its region through a response header.
                        let result = with_concurrency_budget(1, || async {
                            reqwest::Client::builder()
                                .user_agent(USER_AGENT)
                                .build()
                                .unwrap()
                                .head(format!("https://{bucket}.s3.amazonaws.com"))
                                .send()
                                .await
                                .map_err(to_compute_err)
                        })
                        .await?;
                        // If the header is absent the region stays unset.
                        if let Some(region) = result.headers().get("x-amz-bucket-region") {
                            let region =
                                std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?;
                            let mut bucket_region = BUCKET_REGION.lock().unwrap();
                            bucket_region.insert(bucket, region.into());
                            builder = builder.with_config(AmazonS3ConfigKey::Region, region)
                        }
                    }
                },
            };
        };

        let builder = builder.with_retry(self.retry_config.into());

        // A Python provider may opt out through its `_can_use_as_provider` method: it is
        // dropped when that method returns false, and kept when the attribute does not exist.
        let opt_credential_provider = match opt_credential_provider {
            #[cfg(feature = "python")]
            Some(PlCredentialProvider::Python(object)) => {
                if pyo3::Python::attach(|py| {
                    let Ok(func_object) = object
                        .unwrap_as_provider_ref()
                        .getattr(py, "_can_use_as_provider")
                    else {
                        return PolarsResult::Ok(true);
                    };

                    Ok(func_object.call0(py)?.extract::<bool>(py).unwrap())
                })? {
                    Some(PlCredentialProvider::Python(object))
                } else {
                    None
                }
            },

            v => v,
        };

        let builder = if let Some(credential_provider) = opt_credential_provider {
            builder.with_credentials(credential_provider.into_aws_provider())
        } else {
            builder
        };

        let out = builder
            .with_checksum_algorithm(object_store::aws::Checksum::CRC64NVME)
            .with_unsigned_payload(true)
            .build()
            .map_err(|e| ObjectStoreErrorContext::new(url).attach_err_info(e))?;

        Ok(out)
    }
504
505    /// Set the configuration for Azure connections. This is the preferred API from rust.
506    #[cfg(feature = "azure")]
507    pub fn with_azure<I: IntoIterator<Item = (AzureConfigKey, impl Into<String>)>>(
508        mut self,
509        configs: I,
510    ) -> Self {
511        self.config = Some(CloudConfig::Azure(
512            configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
513        ));
514        self
515    }
516
517    /// Build the [`object_store::ObjectStore`] implementation for Azure.
518    #[cfg(feature = "azure")]
519    pub fn build_azure(
520        &self,
521        url: PlRefPath,
522        clear_cached_credentials: bool,
523    ) -> PolarsResult<impl object_store::ObjectStore> {
524        use super::credential_provider::IntoCredentialProvider;
525        use crate::cloud::ObjectStoreErrorContext;
526
527        let verbose = polars_core::config::verbose();
528
529        // The credential provider `self.credentials` is prioritized if it is set. We also need
530        // `from_env()` as it may source environment configured storage account name.
531        let mut builder =
532            MicrosoftAzureBuilder::from_env().with_client_options(get_client_options());
533
534        if let Some(options) = &self.config {
535            let CloudConfig::Azure(options) = options else {
536                panic!("impl error: cloud type mismatch")
537            };
538            for (key, value) in options.iter() {
539                builder = builder.with_config(*key, value);
540            }
541        }
542
543        let builder = builder
544            .with_url(url.to_string())
545            .with_retry(self.retry_config.into());
546
547        let builder =
548            if let Some(v) = self.initialized_credential_provider(clear_cached_credentials)? {
549                if verbose {
550                    eprintln!(
551                        "[CloudOptions::build_azure]: Using credential provider {:?}",
552                        &v
553                    );
554                }
555                builder.with_credentials(v.into_azure_provider())
556            } else {
557                builder
558            };
559
560        let out = builder
561            .build()
562            .map_err(|e| ObjectStoreErrorContext::new(url).attach_err_info(e))?;
563
564        Ok(out)
565    }
566
567    /// Set the configuration for GCP connections. This is the preferred API from rust.
568    #[cfg(feature = "gcp")]
569    pub fn with_gcp<I: IntoIterator<Item = (GoogleConfigKey, impl Into<String>)>>(
570        mut self,
571        configs: I,
572    ) -> Self {
573        self.config = Some(CloudConfig::Gcp(
574            configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
575        ));
576        self
577    }
578
579    /// Build the [`object_store::ObjectStore`] implementation for GCP.
580    #[cfg(feature = "gcp")]
581    pub fn build_gcp(
582        &self,
583        url: PlRefPath,
584        clear_cached_credentials: bool,
585    ) -> PolarsResult<impl object_store::ObjectStore> {
586        use super::credential_provider::IntoCredentialProvider;
587
588        let credential_provider = self.initialized_credential_provider(clear_cached_credentials)?;
589
590        let builder = if credential_provider.is_none() {
591            GoogleCloudStorageBuilder::from_env()
592        } else {
593            GoogleCloudStorageBuilder::new()
594        };
595
596        let mut builder = builder.with_client_options(get_client_options());
597
598        if let Some(options) = &self.config {
599            let CloudConfig::Gcp(options) = options else {
600                panic!("impl error: cloud type mismatch")
601            };
602            for (key, value) in options.iter() {
603                builder = builder.with_config(*key, value);
604            }
605        }
606
607        let builder = builder
608            .with_url(url.to_string())
609            .with_retry(self.retry_config.into());
610
611        let builder = if let Some(v) = credential_provider {
612            builder.with_credentials(v.into_gcp_provider())
613        } else {
614            builder
615        };
616
617        let out = builder
618            .build()
619            .map_err(|e| ObjectStoreErrorContext::new(url).attach_err_info(e))?;
620
621        Ok(out)
622    }
623
624    #[cfg(feature = "http")]
625    pub fn build_http(&self, url: PlRefPath) -> PolarsResult<impl object_store::ObjectStore> {
626        let out = object_store::http::HttpBuilder::new()
627            .with_url(url.to_string())
628            .with_client_options({
629                let mut opts = super::get_client_options();
630                if let Some(CloudConfig::Http { headers }) = &self.config {
631                    opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
632                        headers.as_slice(),
633                    )?);
634                }
635                opts
636            })
637            .build()
638            .map_err(|e| ObjectStoreErrorContext::new(url).attach_err_info(e))?;
639
640        Ok(out)
641    }
642
643    /// Parse a configuration from a Hashmap. This is the interface from Python.
644    #[allow(unused_variables)]
645    pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
646        scheme: Option<CloudScheme>,
647        config: I,
648    ) -> PolarsResult<Self> {
649        match scheme.map_or(CloudType::File, CloudType::from_cloud_scheme) {
650            CloudType::Aws => {
651                #[cfg(feature = "aws")]
652                {
653                    parse_untyped_config::<AmazonS3ConfigKey, _>(config)
654                        .map(|aws| Self::default().with_aws(aws))
655                }
656                #[cfg(not(feature = "aws"))]
657                {
658                    polars_bail!(ComputeError: "'aws' feature is not enabled");
659                }
660            },
661            CloudType::Azure => {
662                #[cfg(feature = "azure")]
663                {
664                    parse_untyped_config::<AzureConfigKey, _>(config)
665                        .map(|azure| Self::default().with_azure(azure))
666                }
667                #[cfg(not(feature = "azure"))]
668                {
669                    polars_bail!(ComputeError: "'azure' feature is not enabled");
670                }
671            },
672            CloudType::File => Ok(Self::default()),
673            CloudType::Http => Ok(Self::default()),
674            CloudType::Gcp => {
675                #[cfg(feature = "gcp")]
676                {
677                    parse_untyped_config::<GoogleConfigKey, _>(config)
678                        .map(|gcp| Self::default().with_gcp(gcp))
679                }
680                #[cfg(not(feature = "gcp"))]
681                {
682                    polars_bail!(ComputeError: "'gcp' feature is not enabled");
683                }
684            },
685            CloudType::Hf => {
686                #[cfg(feature = "http")]
687                {
688                    use polars_core::config;
689
690                    use crate::path_utils::resolve_homedir;
691
692                    let mut this = Self::default();
693                    let mut token = None;
694                    let verbose = config::verbose();
695
696                    for (i, (k, v)) in config.into_iter().enumerate() {
697                        let (k, v) = (k.as_ref(), v.into());
698
699                        if i == 0 && k == "token" {
700                            if verbose {
701                                eprintln!("HF token sourced from storage_options");
702                            }
703                            token = Some(v);
704                        } else {
705                            polars_bail!(ComputeError: "unknown configuration key for HF: {}", k)
706                        }
707                    }
708
709                    token = token
710                        .or_else(|| {
711                            let v = std::env::var("HF_TOKEN").ok();
712                            if v.is_some() && verbose {
713                                eprintln!("HF token sourced from HF_TOKEN env var");
714                            }
715                            v
716                        })
717                        .or_else(|| {
718                            let hf_home = std::env::var("HF_HOME");
719                            let hf_home = hf_home.as_deref();
720                            let hf_home = hf_home.unwrap_or("~/.cache/huggingface");
721                            let hf_home = resolve_homedir(hf_home);
722                            let cached_token_path = hf_home.join("token");
723
724                            let v = std::string::String::from_utf8(
725                                std::fs::read(&cached_token_path).ok()?,
726                            )
727                            .ok()
728                            .filter(|x| !x.is_empty());
729
730                            if v.is_some() && verbose {
731                                eprintln!("HF token sourced from {:?}", cached_token_path);
732                            }
733
734                            v
735                        });
736
737                    if let Some(v) = token {
738                        this.config = Some(CloudConfig::Http {
739                            headers: vec![("Authorization".into(), format!("Bearer {v}"))],
740                        })
741                    }
742
743                    Ok(this)
744                }
745                #[cfg(not(feature = "http"))]
746                {
747                    polars_bail!(ComputeError: "'http' feature is not enabled");
748                }
749            },
750        }
751    }
752
753    /// Python passes a credential provider builder that needs to be called to get the actual credential
754    /// provider.
755    #[cfg(feature = "cloud")]
756    fn initialized_credential_provider(
757        &self,
758        clear_cached_credentials: bool,
759    ) -> PolarsResult<Option<PlCredentialProvider>> {
760        if let Some(v) = self.credential_provider.clone() {
761            v.try_into_initialized(clear_cached_credentials)
762        } else {
763            Ok(None)
764        }
765    }
766}
767
#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
    use hashbrown::HashMap;

    use super::parse_untyped_config;

    /// Keys are matched case-insensitively and keys that are not valid for the provider are
    /// dropped instead of producing an error.
    #[cfg(feature = "aws")]
    #[test]
    fn test_parse_untyped_config() {
        use object_store::aws::AmazonS3ConfigKey;

        // Same expectations for the lowercase and the uppercase spelling of the key.
        for secret_key_name in ["aws_secret_access_key", "AWS_SECRET_ACCESS_KEY"] {
            let aws_config: HashMap<_, _> = [
                (secret_key_name, "a_key"),
                ("aws_s3_allow_unsafe_rename", "true"),
            ]
            .into_iter()
            .collect();

            let aws_keys = parse_untyped_config::<AmazonS3ConfigKey, _>(aws_config)
                .expect("Parsing keys shouldn't have thrown an error");

            assert_eq!(
                aws_keys.first().unwrap().0,
                AmazonS3ConfigKey::SecretAccessKey
            );
            assert_eq!(aws_keys.len(), 1);
        }
    }
}