1#[cfg(feature = "aws")]
2use std::io::Read;
3#[cfg(feature = "aws")]
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::LazyLock;
7
8#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
9use object_store::ClientOptions;
10#[cfg(feature = "aws")]
11use object_store::aws::AmazonS3Builder;
12#[cfg(feature = "aws")]
13pub use object_store::aws::AmazonS3ConfigKey;
14#[cfg(feature = "azure")]
15pub use object_store::azure::AzureConfigKey;
16#[cfg(feature = "azure")]
17use object_store::azure::MicrosoftAzureBuilder;
18#[cfg(feature = "gcp")]
19use object_store::gcp::GoogleCloudStorageBuilder;
20#[cfg(feature = "gcp")]
21pub use object_store::gcp::GoogleConfigKey;
22#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
23use object_store::{BackoffConfig, RetryConfig};
24use polars_error::*;
25#[cfg(feature = "aws")]
26use polars_utils::cache::LruCache;
27#[cfg(feature = "http")]
28use reqwest::header::HeaderMap;
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31#[cfg(feature = "cloud")]
32use url::Url;
33
34#[cfg(feature = "cloud")]
35use super::credential_provider::PlCredentialProvider;
36#[cfg(feature = "file_cache")]
37use crate::file_cache::get_env_file_cache_ttl;
38#[cfg(feature = "aws")]
39use crate::pl_async::with_concurrency_budget;
40
/// Process-wide cache mapping an S3 bucket name to the region it lives in.
///
/// `CloudOptions::build_aws` consults this before issuing a network request to discover a
/// bucket's region and stores the discovered region here afterwards. The cache is an
/// `LruCache` created with a capacity of 32 and is guarded by a standard mutex.
#[cfg(feature = "aws")]
static BUCKET_REGION: LazyLock<
    std::sync::Mutex<LruCache<polars_utils::pl_str::PlSmallStr, polars_utils::pl_str::PlSmallStr>>,
> = LazyLock::new(|| {
    let cache = LruCache::with_capacity(32);
    std::sync::Mutex::new(cache)
});
45
/// Ordered list of `(typed key, value)` configuration pairs for one cloud backend.
// Only referenced from feature-gated items, so it can be unused in some feature combinations.
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;
54
/// Backend-specific settings stored inside `CloudOptions`.
///
/// Every variant is compiled in only when the cargo feature of the same name is enabled.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(crate) enum CloudConfig {
    /// Typed key/value settings applied to the Amazon S3 builder.
    #[cfg(feature = "aws")]
    Aws(Configs<AmazonS3ConfigKey>),
    /// Typed key/value settings applied to the Microsoft Azure builder.
    #[cfg(feature = "azure")]
    Azure(Configs<AzureConfigKey>),
    /// Typed key/value settings applied to the Google Cloud Storage builder.
    #[cfg(feature = "gcp")]
    Gcp(Configs<GoogleConfigKey>),
    /// Default headers, as `(name, value)` pairs, attached to the HTTP client.
    #[cfg(feature = "http")]
    Http { headers: Vec<(String, String)> },
}
67
68#[derive(Clone, Debug, PartialEq, Hash, Eq)]
69#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
70pub struct CloudOptions {
72 pub max_retries: usize,
73 #[cfg(feature = "file_cache")]
74 pub file_cache_ttl: u64,
75 pub(crate) config: Option<CloudConfig>,
76 #[cfg(feature = "cloud")]
77 pub(crate) credential_provider: Option<PlCredentialProvider>,
80}
81
82impl Default for CloudOptions {
83 fn default() -> Self {
84 Self::default_static_ref().clone()
85 }
86}
87
88impl CloudOptions {
89 pub fn default_static_ref() -> &'static Self {
90 static DEFAULT: LazyLock<CloudOptions> = LazyLock::new(|| CloudOptions {
91 max_retries: 2,
92 #[cfg(feature = "file_cache")]
93 file_cache_ttl: get_env_file_cache_ttl(),
94 config: None,
95 #[cfg(feature = "cloud")]
96 credential_provider: None,
97 });
98
99 &DEFAULT
100 }
101}
102
/// Converts a slice of `(name, value)` string pairs into a `HeaderMap`.
///
/// Entries are added with `HeaderMap::insert`, in slice order.
///
/// # Errors
/// Returns a compute error when a name or a value cannot be parsed as an HTTP header
/// name/value. For each pair the name is validated before the value.
#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
    headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
    use reqwest::header::{HeaderName, HeaderValue};

    let mut header_map = HeaderMap::with_capacity(headers.len());

    for (raw_name, raw_value) in headers {
        let name = HeaderName::from_str(raw_name.as_ref()).map_err(to_compute_err)?;
        let value = HeaderValue::from_str(raw_value.as_ref()).map_err(to_compute_err)?;
        header_map.insert(name, value);
    }

    Ok(header_map)
}
120
121#[allow(dead_code)]
122fn parsed_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
124 config: I,
125) -> PolarsResult<Configs<T>>
126where
127 T: FromStr + Eq + std::hash::Hash,
128{
129 Ok(config
130 .into_iter()
131 .filter_map(|(key, val)| {
133 T::from_str(key.as_ref().to_ascii_lowercase().as_str())
134 .ok()
135 .map(|typed_key| (typed_key, val.into()))
136 })
137 .collect::<Configs<T>>())
138}
139
/// The kind of storage backend a URL points at, as determined by its scheme.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the type can be used as a map
/// or set key, in line with the other option types in this module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CloudType {
    /// Amazon S3 (`s3`, `s3a`).
    Aws,
    /// Microsoft Azure storage (`az`, `azure`, `adl`, `abfs`, `abfss`).
    Azure,
    /// Local file system (`file`).
    File,
    /// Google Cloud Storage (`gs`, `gcp`, `gcs`).
    Gcp,
    /// Plain HTTP(S) endpoints (`http`, `https`).
    Http,
    /// Hugging Face (`hf`).
    Hf,
}
149
150impl CloudType {
151 #[cfg(feature = "cloud")]
152 pub(crate) fn from_url(parsed: &Url) -> PolarsResult<Self> {
153 Ok(match parsed.scheme() {
154 "s3" | "s3a" => Self::Aws,
155 "az" | "azure" | "adl" | "abfs" | "abfss" => Self::Azure,
156 "gs" | "gcp" | "gcs" => Self::Gcp,
157 "file" => Self::File,
158 "http" | "https" => Self::Http,
159 "hf" => Self::Hf,
160 _ => polars_bail!(ComputeError: "unknown url scheme"),
161 })
162 }
163}
164
/// Parses `input` as a URL, treating anything without `://` as a local file path.
///
/// * `http://` and `https://` inputs are parsed unchanged.
/// * Other inputs containing `://` have every `%` replaced by `%25` before parsing
///   (presumably so a literal `%` in an object key survives URL parsing - confirm).
/// * Inputs without `://` are treated as file paths; relative paths are resolved against
///   the current working directory and converted to a `file://` URL.
///
/// # Panics
/// For path inputs, panics if the current directory cannot be determined or if the path
/// cannot be converted into a file URL.
#[cfg(feature = "cloud")]
pub(crate) fn parse_url(input: &str) -> std::result::Result<url::Url, url::ParseError> {
    if !input.contains("://") {
        let path = std::path::Path::new(input);
        let absolute: std::borrow::Cow<'_, std::path::Path> = if path.is_relative() {
            std::borrow::Cow::Owned(std::env::current_dir().unwrap().join(path))
        } else {
            std::borrow::Cow::Borrowed(path)
        };
        return Ok(url::Url::from_file_path(absolute).unwrap());
    }

    let is_http = input.starts_with("http://") || input.starts_with("https://");
    if is_http {
        url::Url::parse(input)
    } else {
        let escaped = input.replace("%", "%25");
        url::Url::parse(&escaped)
    }
}
186
187impl FromStr for CloudType {
188 type Err = PolarsError;
189
190 #[cfg(feature = "cloud")]
191 fn from_str(url: &str) -> Result<Self, Self::Err> {
192 let parsed = parse_url(url).map_err(to_compute_err)?;
193 Self::from_url(&parsed)
194 }
195
196 #[cfg(not(feature = "cloud"))]
197 fn from_str(_s: &str) -> Result<Self, Self::Err> {
198 polars_bail!(ComputeError: "at least one of the cloud features must be enabled");
199 }
200}
/// Builds the retry policy used by the cloud builders: default backoff, at most
/// `max_retries` retries and a retry timeout of 10 seconds.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
fn get_retry_config(max_retries: usize) -> RetryConfig {
    let retry_timeout = std::time::Duration::from_secs(10);
    let backoff = BackoffConfig::default();

    RetryConfig {
        backoff,
        max_retries,
        retry_timeout,
    }
}
209
/// Returns the HTTP client options shared by the object-store builders in this module:
/// the request timeout and the connect timeout are disabled and plain `http` is allowed.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
    let options = ClientOptions::new();
    let options = options.with_timeout_disabled();
    let options = options.with_connect_timeout_disabled();
    options.with_allow_http(true)
}
221
/// Fills missing S3 builder settings from local configuration files.
///
/// `items` pairs a file path (a leading `~` is expanded by `resolve_homedir`) with a list
/// of `(regex pattern, config key)` entries. For every key the builder does not have yet,
/// the file content is searched with the pattern and the first capture group becomes the
/// value. Keys that are already set are never overwritten, and a file is not read at all
/// when every one of its keys is already set.
///
/// Returns `None` as soon as a file cannot be opened or read as UTF-8, or a pattern does
/// not match (or has no first capture group); remaining keys and files are then skipped.
/// Returns `Some(())` when all items were processed.
///
/// # Panics
/// Panics if a pattern is not a valid regular expression.
#[cfg(feature = "aws")]
fn read_config(
    builder: &mut AmazonS3Builder,
    items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
    use crate::path_utils::resolve_homedir;

    for (path, keys) in items {
        let any_key_missing = keys
            .iter()
            .any(|(_, key)| builder.get_config_value(key).is_none());
        if !any_key_missing {
            continue;
        }

        let mut file = std::fs::File::open(resolve_homedir(path)).ok()?;
        let mut content = String::new();
        file.read_to_string(&mut content).ok()?;

        for (pattern, key) in keys.iter() {
            if builder.get_config_value(key).is_some() {
                continue;
            }

            let regex = polars_utils::regex_cache::compile_regex(pattern).unwrap();
            let value = regex.captures(content.as_str())?.get(1)?.as_str();
            // The builder methods consume `self`, so move it out and put the result back.
            *builder = std::mem::take(builder).with_config(*key, value);
        }
    }

    Some(())
}
254
255impl CloudOptions {
256 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
258 self.max_retries = max_retries;
259 self
260 }
261
262 #[cfg(feature = "cloud")]
263 pub fn with_credential_provider(
264 mut self,
265 credential_provider: Option<PlCredentialProvider>,
266 ) -> Self {
267 self.credential_provider = credential_provider;
268 self
269 }
270
271 #[cfg(feature = "aws")]
273 pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
274 mut self,
275 configs: I,
276 ) -> Self {
277 self.config = Some(CloudConfig::Aws(
278 configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
279 ));
280 self
281 }
282
283 #[cfg(feature = "aws")]
285 pub async fn build_aws(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
286 use super::credential_provider::IntoCredentialProvider;
287
288 let mut builder = AmazonS3Builder::from_env()
289 .with_client_options(get_client_options())
290 .with_url(url);
291
292 read_config(
293 &mut builder,
294 &[(
295 Path::new("~/.aws/config"),
296 &[("region\\s*=\\s*([^\r\n]*)", AmazonS3ConfigKey::Region)],
297 )],
298 );
299
300 read_config(
301 &mut builder,
302 &[(
303 Path::new("~/.aws/credentials"),
304 &[
305 (
306 "aws_access_key_id\\s*=\\s*([^\\r\\n]*)",
307 AmazonS3ConfigKey::AccessKeyId,
308 ),
309 (
310 "aws_secret_access_key\\s*=\\s*([^\\r\\n]*)",
311 AmazonS3ConfigKey::SecretAccessKey,
312 ),
313 (
314 "aws_session_token\\s*=\\s*([^\\r\\n]*)",
315 AmazonS3ConfigKey::Token,
316 ),
317 ],
318 )],
319 );
320
321 if let Some(options) = &self.config {
322 let CloudConfig::Aws(options) = options else {
323 panic!("impl error: cloud type mismatch")
324 };
325 for (key, value) in options.iter() {
326 builder = builder.with_config(*key, value);
327 }
328 }
329
330 if builder
331 .get_config_value(&AmazonS3ConfigKey::DefaultRegion)
332 .is_none()
333 && builder
334 .get_config_value(&AmazonS3ConfigKey::Region)
335 .is_none()
336 {
337 let bucket = crate::cloud::CloudLocation::new(url, false)?.bucket;
338 let region = {
339 let mut bucket_region = BUCKET_REGION.lock().unwrap();
340 bucket_region.get(bucket.as_str()).cloned()
341 };
342
343 match region {
344 Some(region) => {
345 builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str())
346 },
347 None => {
348 if builder
349 .get_config_value(&AmazonS3ConfigKey::Endpoint)
350 .is_some()
351 {
352 builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1");
355 } else {
356 polars_warn!(
357 "'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."
358 );
359 let result = with_concurrency_budget(1, || async {
360 reqwest::Client::builder()
361 .build()
362 .unwrap()
363 .head(format!("https://{bucket}.s3.amazonaws.com"))
364 .send()
365 .await
366 .map_err(to_compute_err)
367 })
368 .await?;
369 if let Some(region) = result.headers().get("x-amz-bucket-region") {
370 let region =
371 std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?;
372 let mut bucket_region = BUCKET_REGION.lock().unwrap();
373 bucket_region.insert(bucket, region.into());
374 builder = builder.with_config(AmazonS3ConfigKey::Region, region)
375 }
376 }
377 },
378 };
379 };
380
381 let builder = builder.with_retry(get_retry_config(self.max_retries));
382
383 let builder = if let Some(v) = self.initialized_credential_provider()? {
384 builder.with_credentials(v.into_aws_provider())
385 } else {
386 builder
387 };
388
389 builder.build().map_err(to_compute_err)
390 }
391
392 #[cfg(feature = "azure")]
394 pub fn with_azure<I: IntoIterator<Item = (AzureConfigKey, impl Into<String>)>>(
395 mut self,
396 configs: I,
397 ) -> Self {
398 self.config = Some(CloudConfig::Azure(
399 configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
400 ));
401 self
402 }
403
404 #[cfg(feature = "azure")]
406 pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
407 use super::credential_provider::IntoCredentialProvider;
408
409 let verbose = polars_core::config::verbose();
410
411 let mut builder =
414 MicrosoftAzureBuilder::from_env().with_client_options(get_client_options());
415
416 if let Some(options) = &self.config {
417 let CloudConfig::Azure(options) = options else {
418 panic!("impl error: cloud type mismatch")
419 };
420 for (key, value) in options.iter() {
421 builder = builder.with_config(*key, value);
422 }
423 }
424
425 let builder = builder
426 .with_url(url)
427 .with_retry(get_retry_config(self.max_retries));
428
429 let builder = if let Some(v) = self.initialized_credential_provider()? {
430 if verbose {
431 eprintln!(
432 "[CloudOptions::build_azure]: Using credential provider {:?}",
433 &v
434 );
435 }
436 builder.with_credentials(v.into_azure_provider())
437 } else {
438 builder
439 };
440
441 builder.build().map_err(to_compute_err)
442 }
443
444 #[cfg(feature = "gcp")]
446 pub fn with_gcp<I: IntoIterator<Item = (GoogleConfigKey, impl Into<String>)>>(
447 mut self,
448 configs: I,
449 ) -> Self {
450 self.config = Some(CloudConfig::Gcp(
451 configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
452 ));
453 self
454 }
455
456 #[cfg(feature = "gcp")]
458 pub fn build_gcp(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
459 use super::credential_provider::IntoCredentialProvider;
460
461 let credential_provider = self.initialized_credential_provider()?;
462
463 let builder = if credential_provider.is_none() {
464 GoogleCloudStorageBuilder::from_env()
465 } else {
466 GoogleCloudStorageBuilder::new()
467 };
468
469 let mut builder = builder.with_client_options(get_client_options());
470
471 if let Some(options) = &self.config {
472 let CloudConfig::Gcp(options) = options else {
473 panic!("impl error: cloud type mismatch")
474 };
475 for (key, value) in options.iter() {
476 builder = builder.with_config(*key, value);
477 }
478 }
479
480 let builder = builder
481 .with_url(url)
482 .with_retry(get_retry_config(self.max_retries));
483
484 let builder = if let Some(v) = credential_provider.clone() {
485 builder.with_credentials(v.into_gcp_provider())
486 } else {
487 builder
488 };
489
490 builder.build().map_err(to_compute_err)
491 }
492
493 #[cfg(feature = "http")]
494 pub fn build_http(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
495 object_store::http::HttpBuilder::new()
496 .with_url(url)
497 .with_client_options({
498 let mut opts = super::get_client_options();
499 if let Some(CloudConfig::Http { headers }) = &self.config {
500 opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
501 headers.as_slice(),
502 )?);
503 }
504 opts
505 })
506 .build()
507 .map_err(to_compute_err)
508 }
509
510 #[allow(unused_variables)]
512 pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
513 url: &str,
514 config: I,
515 ) -> PolarsResult<Self> {
516 match CloudType::from_str(url)? {
517 CloudType::Aws => {
518 #[cfg(feature = "aws")]
519 {
520 parsed_untyped_config::<AmazonS3ConfigKey, _>(config)
521 .map(|aws| Self::default().with_aws(aws))
522 }
523 #[cfg(not(feature = "aws"))]
524 {
525 polars_bail!(ComputeError: "'aws' feature is not enabled");
526 }
527 },
528 CloudType::Azure => {
529 #[cfg(feature = "azure")]
530 {
531 parsed_untyped_config::<AzureConfigKey, _>(config)
532 .map(|azure| Self::default().with_azure(azure))
533 }
534 #[cfg(not(feature = "azure"))]
535 {
536 polars_bail!(ComputeError: "'azure' feature is not enabled");
537 }
538 },
539 CloudType::File => Ok(Self::default()),
540 CloudType::Http => Ok(Self::default()),
541 CloudType::Gcp => {
542 #[cfg(feature = "gcp")]
543 {
544 parsed_untyped_config::<GoogleConfigKey, _>(config)
545 .map(|gcp| Self::default().with_gcp(gcp))
546 }
547 #[cfg(not(feature = "gcp"))]
548 {
549 polars_bail!(ComputeError: "'gcp' feature is not enabled");
550 }
551 },
552 CloudType::Hf => {
553 #[cfg(feature = "http")]
554 {
555 use polars_core::config;
556
557 use crate::path_utils::resolve_homedir;
558
559 let mut this = Self::default();
560 let mut token = None;
561 let verbose = config::verbose();
562
563 for (i, (k, v)) in config.into_iter().enumerate() {
564 let (k, v) = (k.as_ref(), v.into());
565
566 if i == 0 && k == "token" {
567 if verbose {
568 eprintln!("HF token sourced from storage_options");
569 }
570 token = Some(v);
571 } else {
572 polars_bail!(ComputeError: "unknown configuration key for HF: {}", k)
573 }
574 }
575
576 token = token
577 .or_else(|| {
578 let v = std::env::var("HF_TOKEN").ok();
579 if v.is_some() && verbose {
580 eprintln!("HF token sourced from HF_TOKEN env var");
581 }
582 v
583 })
584 .or_else(|| {
585 let hf_home = std::env::var("HF_HOME");
586 let hf_home = hf_home.as_deref();
587 let hf_home = hf_home.unwrap_or("~/.cache/huggingface");
588 let hf_home = resolve_homedir(&hf_home);
589 let cached_token_path = hf_home.join("token");
590
591 let v = std::string::String::from_utf8(
592 std::fs::read(&cached_token_path).ok()?,
593 )
594 .ok()
595 .filter(|x| !x.is_empty());
596
597 if v.is_some() && verbose {
598 eprintln!(
599 "HF token sourced from {}",
600 cached_token_path.to_str().unwrap()
601 );
602 }
603
604 v
605 });
606
607 if let Some(v) = token {
608 this.config = Some(CloudConfig::Http {
609 headers: vec![("Authorization".into(), format!("Bearer {}", v))],
610 })
611 }
612
613 Ok(this)
614 }
615 #[cfg(not(feature = "http"))]
616 {
617 polars_bail!(ComputeError: "'http' feature is not enabled");
618 }
619 },
620 }
621 }
622
623 #[cfg(feature = "cloud")]
626 fn initialized_credential_provider(&self) -> PolarsResult<Option<PlCredentialProvider>> {
627 if let Some(v) = self.credential_provider.clone() {
628 v.try_into_initialized()
629 } else {
630 Ok(None)
631 }
632 }
633}
634
#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
    use hashbrown::HashMap;

    use super::{parse_url, parsed_untyped_config};

    /// `parse_url` handles http URLs, `file://` URLs, absolute paths and relative paths.
    #[test]
    fn test_parse_url() {
        assert_eq!(
            parse_url(r"http://Users/Jane Doe/data.csv")
                .unwrap()
                .as_str(),
            "http://users/Jane%20Doe/data.csv"
        );
        #[cfg(target_os = "windows")]
        {
            assert_eq!(
                parse_url(r"file:///c:/Users/Jane Doe/data.csv")
                    .unwrap()
                    .as_str(),
                "file:///c:/Users/Jane%20Doe/data.csv"
            );
            assert_eq!(
                parse_url(r"file://\c:\Users\Jane Doe\data.csv")
                    .unwrap()
                    .as_str(),
                "file:///c:/Users/Jane%20Doe/data.csv"
            );
            assert_eq!(
                parse_url(r"c:\Users\Jane Doe\data.csv").unwrap().as_str(),
                "file:///C:/Users/Jane%20Doe/data.csv"
            );
        }
        #[cfg(not(target_os = "windows"))]
        {
            assert_eq!(
                parse_url(r"file:///home/Jane Doe/data.csv")
                    .unwrap()
                    .as_str(),
                "file:///home/Jane%20Doe/data.csv"
            );
            assert_eq!(
                parse_url(r"/home/Jane Doe/data.csv").unwrap().as_str(),
                "file:///home/Jane%20Doe/data.csv"
            );
        }

        // Relative paths are resolved against the current directory on every platform.
        let expected =
            url::Url::from_file_path(std::env::current_dir().unwrap().join("data.csv")).unwrap();
        assert_eq!(parse_url(r"data.csv").unwrap().as_str(), expected.as_str());
    }

    /// `%` is escaped to `%25` for every scheme except http/https.
    #[test]
    fn test_parse_url_percent_escaping() {
        assert_eq!(
            parse_url("s3://bucket/a%20b.csv").unwrap().as_str(),
            "s3://bucket/a%2520b.csv"
        );
        assert_eq!(
            parse_url("https://host/a%20b.csv").unwrap().as_str(),
            "https://host/a%20b.csv"
        );
    }

    /// Known keys are matched case-insensitively; unknown keys are dropped.
    #[cfg(feature = "aws")]
    #[test]
    fn test_parse_untyped_config() {
        use object_store::aws::AmazonS3ConfigKey;

        for secret_key in ["aws_secret_access_key", "AWS_SECRET_ACCESS_KEY"] {
            let aws_config = [(secret_key, "a_key"), ("aws_s3_allow_unsafe_rename", "true")]
                .into_iter()
                .collect::<HashMap<_, _>>();
            let aws_keys = parsed_untyped_config::<AmazonS3ConfigKey, _>(aws_config)
                .expect("Parsing keys shouldn't have thrown an error");

            assert_eq!(
                aws_keys.first().unwrap().0,
                AmazonS3ConfigKey::SecretAccessKey
            );
            assert_eq!(aws_keys.len(), 1);
        }
    }
}