#[cfg(feature = "aws")]
use std::io::Read;
#[cfg(feature = "aws")]
use std::path::Path;
use std::str::FromStr;
#[cfg(feature = "aws")]
use object_store::aws::AmazonS3Builder;
#[cfg(feature = "aws")]
pub use object_store::aws::AmazonS3ConfigKey;
#[cfg(feature = "azure")]
pub use object_store::azure::AzureConfigKey;
#[cfg(feature = "azure")]
use object_store::azure::MicrosoftAzureBuilder;
#[cfg(feature = "gcp")]
use object_store::gcp::GoogleCloudStorageBuilder;
#[cfg(feature = "gcp")]
pub use object_store::gcp::GoogleConfigKey;
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
use object_store::ClientOptions;
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
use object_store::{BackoffConfig, RetryConfig};
#[cfg(feature = "aws")]
use once_cell::sync::Lazy;
use polars_error::*;
#[cfg(feature = "aws")]
use polars_utils::cache::FastFixedCache;
#[cfg(feature = "aws")]
use regex::Regex;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "aws")]
use smartstring::alias::String as SmartString;
#[cfg(feature = "cloud")]
use url::Url;
#[cfg(feature = "aws")]
use crate::pl_async::with_concurrency_budget;
#[cfg(feature = "aws")]
use crate::utils::resolve_homedir;
/// Process-wide cache mapping an S3 bucket name to its region.
///
/// Read and filled by `CloudOptions::build_aws` so that the HTTP lookup of a bucket's
/// region (the `x-amz-bucket-region` header) is only done once per bucket.
/// NOTE(review): `FastFixedCache::new(32)` presumably bounds this to 32 entries — confirm.
#[cfg(feature = "aws")]
static BUCKET_REGION: Lazy<std::sync::Mutex<FastFixedCache<SmartString, SmartString>>> =
Lazy::new(|| std::sync::Mutex::new(FastFixedCache::new(32)));
// Only used by feature-gated code; unused when no cloud provider feature is enabled.
#[allow(dead_code)]
/// Provider-specific configuration as `(typed key, value)` pairs, in insertion order.
type Configs<T> = Vec<(T, String)>;
/// Options used to connect to cloud object stores.
///
/// Provider-specific settings only exist when the matching crate feature (`aws`, `azure`,
/// `gcp`) is enabled. They are filled in through the `with_*` builder methods or
/// `CloudOptions::from_untyped_config`.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CloudOptions {
    /// Amazon S3 configuration; `None` until `with_aws` is called.
    #[cfg(feature = "aws")]
    aws: Option<Configs<AmazonS3ConfigKey>>,
    /// Azure configuration; `None` until `with_azure` is called.
    #[cfg(feature = "azure")]
    azure: Option<Configs<AzureConfigKey>>,
    /// Google Cloud Storage configuration; `None` until `with_gcp` is called.
    #[cfg(feature = "gcp")]
    gcp: Option<Configs<GoogleConfigKey>>,
    /// Maximum number of retries passed to the object store's retry policy.
    pub max_retries: usize,
}

impl Default for CloudOptions {
    /// No provider-specific configuration and two retries.
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: usize = 2;
        Self {
            #[cfg(feature = "aws")]
            aws: None,
            #[cfg(feature = "azure")]
            azure: None,
            #[cfg(feature = "gcp")]
            gcp: None,
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }
}
/// Convert string-keyed configuration pairs into typed `(key, value)` pairs.
///
/// Each key is parsed with `T::from_str`; values are converted with `Into<String>`.
/// Pairs keep their input order.
///
/// # Errors
///
/// Stops at the first key that fails to parse and returns a `ComputeError` naming it.
// Only reached from feature-gated branches of `CloudOptions::from_untyped_config`.
#[allow(dead_code)]
fn parsed_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
    config: I,
) -> PolarsResult<Configs<T>>
where
    T: FromStr + Eq + std::hash::Hash,
{
    let mut typed: Configs<T> = Vec::new();
    for (raw_key, raw_value) in config {
        let name = raw_key.as_ref();
        match T::from_str(name) {
            Ok(typed_key) => typed.push((typed_key, raw_value.into())),
            Err(_) => {
                return Err(polars_err!(ComputeError: "unknown configuration key: {}", name));
            },
        }
    }
    Ok(typed)
}
/// The storage backend a path or URL refers to, derived from its scheme.
#[derive(PartialEq)]
pub enum CloudType {
    /// Amazon S3 (`s3://`, `s3a://`).
    Aws,
    /// Azure storage (`az://`, `azure://`, `adl://`, `abfs://`, `abfss://`).
    Azure,
    /// Local filesystem (`file://`; `parse_url` turns plain paths into `file://` URLs).
    File,
    /// Google Cloud Storage (`gs://`, `gcp://`, `gcs://`).
    Gcp,
    /// HTTP(S) (`http://`, `https://`).
    Http,
}

impl CloudType {
    /// Classify an already-parsed URL by its scheme.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` that names the scheme when it is not one of the
    /// supported schemes listed on the variants of [`CloudType`].
    #[cfg(feature = "cloud")]
    pub(crate) fn from_url(parsed: &Url) -> PolarsResult<Self> {
        Ok(match parsed.scheme() {
            "s3" | "s3a" => Self::Aws,
            "az" | "azure" | "adl" | "abfs" | "abfss" => Self::Azure,
            "gs" | "gcp" | "gcs" => Self::Gcp,
            "file" => Self::File,
            "http" | "https" => Self::Http,
            // Keep the original message as a prefix, but tell the user which scheme was rejected.
            scheme => polars_bail!(ComputeError: "unknown url scheme: '{}'", scheme),
        })
    }
}
/// Parse `input` into a [`url::Url`].
///
/// Input containing `://` is parsed as a URL as-is. Anything else is treated as a local
/// filesystem path. Relative paths are resolved against the current working directory,
/// and the absolute path is converted into a `file://` URL.
///
/// # Errors
///
/// For URL-like input, returns the error from [`url::Url::parse`]. For local paths,
/// returns [`url::ParseError::RelativeUrlWithoutBase`] when no absolute path can be formed.
/// That happens when the current directory cannot be read, or when
/// [`url::Url::from_file_path`] rejects the path. The previous version panicked here.
#[cfg(feature = "cloud")]
pub(crate) fn parse_url(input: &str) -> std::result::Result<url::Url, url::ParseError> {
    if input.contains("://") {
        return url::Url::parse(input);
    }
    let path = std::path::Path::new(input);
    // `Url::from_file_path` only accepts absolute paths, so anchor relative ones at the cwd.
    let absolute = if path.is_relative() {
        std::env::current_dir()
            .map_err(|_| url::ParseError::RelativeUrlWithoutBase)?
            .join(path)
    } else {
        path.to_path_buf()
    };
    url::Url::from_file_path(absolute).map_err(|()| url::ParseError::RelativeUrlWithoutBase)
}
impl FromStr for CloudType {
    type Err = PolarsError;

    /// Parse `url` (a URL or a local path) with `parse_url` and classify it by scheme.
    #[cfg(feature = "cloud")]
    fn from_str(url: &str) -> Result<Self, Self::Err> {
        match parse_url(url) {
            Ok(parsed) => Self::from_url(&parsed),
            Err(parse_error) => Err(to_compute_err(parse_error)),
        }
    }

    /// Without the `cloud` feature there is no URL support, so every input is rejected.
    #[cfg(not(feature = "cloud"))]
    fn from_str(_s: &str) -> Result<Self, Self::Err> {
        polars_bail!(ComputeError: "at least one of the cloud features must be enabled");
    }
}
/// Retry policy shared by the S3, Azure and GCS builders.
///
/// Uses the default backoff, the caller's `max_retries` and a 10 second `retry_timeout`.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
fn get_retry_config(max_retries: usize) -> RetryConfig {
    let retry_timeout = std::time::Duration::from_secs(10);
    let backoff = BackoffConfig::default();
    RetryConfig {
        max_retries,
        retry_timeout,
        backoff,
    }
}
/// HTTP client options shared by all object-store builders.
///
/// Starts from `ClientOptions::default()`. Disables both the request timeout and the
/// connect timeout, and allows plain `http://` connections.
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
    let base = ClientOptions::default();
    let without_timeouts = base
        .with_timeout_disabled()
        .with_connect_timeout_disabled();
    without_timeouts.with_allow_http(true)
}
/// Fill in missing S3 configuration values by scraping local config files.
///
/// Each entry of `items` pairs a file path with a list of `(regex, key)` pairs. The path
/// may start with `~`, which is resolved via `resolve_homedir`. For every key that
/// `builder` does not already have a value for, the regex is run against the file
/// contents and capture group 1 is stored under that key. Values already present on
/// `builder` are never overwritten.
///
/// Returns `None` as soon as a file cannot be opened, read or decoded as UTF-8, or a
/// needed pattern does not match. `builder` keeps every value set up to that point.
/// This is a best-effort lookup.
#[cfg(feature = "aws")]
fn read_config(
    builder: &mut AmazonS3Builder,
    items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
    for (path, keys) in items {
        // Don't touch the file at all if every key it could provide is already set.
        if keys
            .iter()
            .all(|(_, key)| builder.get_config_value(key).is_some())
        {
            continue;
        }
        let mut config = std::fs::File::open(resolve_homedir(path)).ok()?;
        let mut buf = vec![];
        config.read_to_end(&mut buf).ok()?;
        let content = std::str::from_utf8(buf.as_ref()).ok()?;
        for (pattern, key) in keys.iter() {
            // Inspect the *current* builder. Taking it out first (as before) checked an empty
            // default builder, overwrote explicitly provided values, and lost all configuration
            // when the pattern below failed to match.
            if builder.get_config_value(key).is_some() {
                continue;
            }
            let reg = Regex::new(pattern).expect("config patterns are valid regex literals");
            let cap = reg.captures(content)?;
            let parsed = cap.get(1)?.as_str();
            // `with_config` consumes the builder, so move it out only once we have a value.
            *builder = std::mem::take(builder).with_config(*key, parsed);
        }
    }
    Some(())
}
impl CloudOptions {
    /// Set the Amazon S3 configuration, replacing any AWS options set earlier.
    #[cfg(feature = "aws")]
    pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
        mut self,
        configs: I,
    ) -> Self {
        self.aws = Some(
            configs
                .into_iter()
                .map(|(k, v)| (k, v.into()))
                .collect::<Configs<AmazonS3ConfigKey>>(),
        );
        self
    }

    /// Build an Amazon S3 object store for `url`.
    ///
    /// The builder starts from `AmazonS3Builder::from_env()` with `url`, then applies the
    /// options set through [`CloudOptions::with_aws`]. After that, `read_config` is asked to
    /// fill in the region from `~/.aws/config` and the access key id / secret access key
    /// from `~/.aws/credentials`.
    ///
    /// If neither `DefaultRegion` nor `Region` is configured at that point, the region is
    /// taken from the first of these that applies:
    /// 1. the process-wide `BUCKET_REGION` cache;
    /// 2. `us-east-1`, when a custom `Endpoint` is set;
    /// 3. the `x-amz-bucket-region` header of an HTTP HEAD request to
    ///    `https://{bucket}.s3.amazonaws.com`. This case emits a warning and caches the
    ///    result on success.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the bucket cannot be extracted from `url`;
    /// - the HTTP client cannot be built, or the region request fails;
    /// - the region header is not valid UTF-8;
    /// - the final `AmazonS3Builder::build` fails.
    #[cfg(feature = "aws")]
    pub async fn build_aws(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
        let mut builder = AmazonS3Builder::from_env().with_url(url);
        if let Some(options) = self.aws.as_ref() {
            for (key, value) in options.iter() {
                builder = builder.with_config(*key, value);
            }
        }

        // Best-effort: missing or unreadable files are fine, so the results are ignored.
        read_config(
            &mut builder,
            &[(
                Path::new("~/.aws/config"),
                &[("region = (.*)\n", AmazonS3ConfigKey::Region)],
            )],
        );
        read_config(
            &mut builder,
            &[(
                Path::new("~/.aws/credentials"),
                &[
                    ("aws_access_key_id = (.*)\n", AmazonS3ConfigKey::AccessKeyId),
                    (
                        "aws_secret_access_key = (.*)\n",
                        AmazonS3ConfigKey::SecretAccessKey,
                    ),
                ],
            )],
        );

        if builder
            .get_config_value(&AmazonS3ConfigKey::DefaultRegion)
            .is_none()
            && builder
                .get_config_value(&AmazonS3ConfigKey::Region)
                .is_none()
        {
            let bucket = crate::cloud::CloudLocation::new(url)?.bucket;
            // Scope the std mutex guard to the lookup so it is never held across an `.await`.
            let region = {
                let bucket_region = BUCKET_REGION.lock().unwrap();
                bucket_region.get(bucket.as_str()).cloned()
            };

            match region {
                Some(region) => {
                    builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str())
                },
                None => {
                    if builder
                        .get_config_value(&AmazonS3ConfigKey::Endpoint)
                        .is_some()
                    {
                        // The amazonaws.com probe below does not apply to a custom endpoint.
                        builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1");
                    } else {
                        polars_warn!("'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning.");
                        // Building the client can fail (e.g. TLS backend initialisation);
                        // report that as an error instead of panicking inside the library.
                        let client = reqwest::Client::builder()
                            .build()
                            .map_err(to_compute_err)?;
                        let result = with_concurrency_budget(1, || async {
                            client
                                .head(format!("https://{bucket}.s3.amazonaws.com"))
                                .send()
                                .await
                                .map_err(to_compute_err)
                        })
                        .await?;
                        if let Some(region) = result.headers().get("x-amz-bucket-region") {
                            let region =
                                std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?;
                            let mut bucket_region = BUCKET_REGION.lock().unwrap();
                            bucket_region.insert(bucket.into(), region.into());
                            builder = builder.with_config(AmazonS3ConfigKey::Region, region)
                        }
                    }
                },
            }
        }

        builder
            .with_client_options(get_client_options())
            .with_retry(get_retry_config(self.max_retries))
            .build()
            .map_err(to_compute_err)
    }

    /// Set the Azure configuration, replacing any Azure options set earlier.
    #[cfg(feature = "azure")]
    pub fn with_azure<I: IntoIterator<Item = (AzureConfigKey, impl Into<String>)>>(
        mut self,
        configs: I,
    ) -> Self {
        self.azure = Some(
            configs
                .into_iter()
                .map(|(k, v)| (k, v.into()))
                .collect::<Configs<AzureConfigKey>>(),
        );
        self
    }

    /// Build an Azure object store for `url`.
    ///
    /// Starts from `MicrosoftAzureBuilder::from_env()`, applies the options set through
    /// [`CloudOptions::with_azure`], then the shared client options, `url` and retry policy.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` if the builder rejects the configuration.
    #[cfg(feature = "azure")]
    pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
        let mut builder = MicrosoftAzureBuilder::from_env();
        if let Some(options) = self.azure.as_ref() {
            for (key, value) in options.iter() {
                builder = builder.with_config(*key, value);
            }
        }

        builder
            .with_client_options(get_client_options())
            .with_url(url)
            .with_retry(get_retry_config(self.max_retries))
            .build()
            .map_err(to_compute_err)
    }

    /// Set the Google Cloud Storage configuration, replacing any GCP options set earlier.
    #[cfg(feature = "gcp")]
    pub fn with_gcp<I: IntoIterator<Item = (GoogleConfigKey, impl Into<String>)>>(
        mut self,
        configs: I,
    ) -> Self {
        self.gcp = Some(
            configs
                .into_iter()
                .map(|(k, v)| (k, v.into()))
                .collect::<Configs<GoogleConfigKey>>(),
        );
        self
    }

    /// Build a Google Cloud Storage object store for `url`.
    ///
    /// Starts from `GoogleCloudStorageBuilder::from_env()`, applies the options set through
    /// [`CloudOptions::with_gcp`], then the shared client options, `url` and retry policy.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` if the builder rejects the configuration.
    #[cfg(feature = "gcp")]
    pub fn build_gcp(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
        let mut builder = GoogleCloudStorageBuilder::from_env();
        if let Some(options) = self.gcp.as_ref() {
            for (key, value) in options.iter() {
                builder = builder.with_config(*key, value);
            }
        }

        builder
            .with_client_options(get_client_options())
            .with_url(url)
            .with_retry(get_retry_config(self.max_retries))
            .build()
            .map_err(to_compute_err)
    }

    /// Build options from string key/value pairs, for the provider implied by `url`.
    ///
    /// The scheme of `url` selects which provider's key type the keys are parsed as.
    /// `file` and `http(s)` URLs take no provider configuration; they yield
    /// `CloudOptions::default()` and `config` is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `url` has an unsupported scheme;
    /// - a key is unknown for the selected provider;
    /// - the provider's crate feature is not enabled.
    // `config` is unused when none of the provider features are enabled.
    #[allow(unused_variables)]
    pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
        url: &str,
        config: I,
    ) -> PolarsResult<Self> {
        match CloudType::from_str(url)? {
            CloudType::Aws => {
                #[cfg(feature = "aws")]
                {
                    parsed_untyped_config::<AmazonS3ConfigKey, _>(config)
                        .map(|aws| Self::default().with_aws(aws))
                }
                #[cfg(not(feature = "aws"))]
                {
                    polars_bail!(ComputeError: "'aws' feature is not enabled");
                }
            },
            CloudType::Azure => {
                #[cfg(feature = "azure")]
                {
                    parsed_untyped_config::<AzureConfigKey, _>(config)
                        .map(|azure| Self::default().with_azure(azure))
                }
                #[cfg(not(feature = "azure"))]
                {
                    polars_bail!(ComputeError: "'azure' feature is not enabled");
                }
            },
            CloudType::File => Ok(Self::default()),
            CloudType::Http => Ok(Self::default()),
            CloudType::Gcp => {
                #[cfg(feature = "gcp")]
                {
                    parsed_untyped_config::<GoogleConfigKey, _>(config)
                        .map(|gcp| Self::default().with_gcp(gcp))
                }
                #[cfg(not(feature = "gcp"))]
                {
                    polars_bail!(ComputeError: "'gcp' feature is not enabled");
                }
            },
        }
    }
}
#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
    use super::parse_url;

    /// Expected `file://` URL for `data.csv` in the current working directory.
    fn cwd_data_csv_url() -> url::Url {
        url::Url::from_file_path(
            [
                std::env::current_dir().unwrap().as_path(),
                std::path::Path::new("data.csv"),
            ]
            .into_iter()
            .collect::<std::path::PathBuf>(),
        )
        .unwrap()
    }

    #[test]
    fn test_parse_url() {
        // Explicit URLs are parsed as-is: the host is lowercased, spaces are percent-encoded.
        assert_eq!(
            parse_url(r"http://Users/Jane Doe/data.csv")
                .unwrap()
                .as_str(),
            "http://users/Jane%20Doe/data.csv"
        );
        assert_eq!(
            parse_url(r"https://Users/Jane Doe/data.csv")
                .unwrap()
                .as_str(),
            "https://users/Jane%20Doe/data.csv"
        );
        // Relative paths are resolved against the current working directory on every OS.
        assert_eq!(
            parse_url(r"data.csv").unwrap().as_str(),
            cwd_data_csv_url().as_str()
        );

        #[cfg(target_os = "windows")]
        {
            assert_eq!(
                parse_url(r"file:///c:/Users/Jane Doe/data.csv")
                    .unwrap()
                    .as_str(),
                "file:///c:/Users/Jane%20Doe/data.csv"
            );
            assert_eq!(
                parse_url(r"file://\c:\Users\Jane Doe\data.csv")
                    .unwrap()
                    .as_str(),
                "file:///c:/Users/Jane%20Doe/data.csv"
            );
            assert_eq!(
                parse_url(r"c:\Users\Jane Doe\data.csv").unwrap().as_str(),
                "file:///C:/Users/Jane%20Doe/data.csv"
            );
        }

        #[cfg(not(target_os = "windows"))]
        {
            assert_eq!(
                parse_url(r"file:///home/Jane Doe/data.csv")
                    .unwrap()
                    .as_str(),
                "file:///home/Jane%20Doe/data.csv"
            );
            assert_eq!(
                parse_url(r"/home/Jane Doe/data.csv").unwrap().as_str(),
                "file:///home/Jane%20Doe/data.csv"
            );
        }
    }
}