1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use async_trait::async_trait;
9#[cfg(feature = "aws")]
10pub use object_store::aws::AwsCredential;
11#[cfg(feature = "azure")]
12pub use object_store::azure::AzureCredential;
13#[cfg(feature = "gcp")]
14pub use object_store::gcp::GcpCredential;
15use polars_core::config;
16use polars_error::{PolarsResult, polars_bail};
17#[cfg(feature = "python")]
18use polars_utils::python_function::PythonObject;
19#[cfg(feature = "python")]
20use python_impl::PythonCredentialProvider;
21
22#[derive(Clone, Debug, PartialEq, Hash, Eq)]
23pub enum PlCredentialProvider {
24 Function(CredentialProviderFunction),
26 #[cfg(feature = "python")]
27 Python(python_impl::PythonCredentialProvider),
28}
29
30impl PlCredentialProvider {
31 pub fn from_func(
35 func: impl Fn() -> Pin<
38 Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
39 > + Send
40 + Sync
41 + 'static,
42 ) -> Self {
43 Self::Function(CredentialProviderFunction(Arc::new(func)))
44 }
45
46 #[cfg(feature = "python")]
49 pub fn from_python_builder(func: pyo3::PyObject) -> Self {
50 Self::Python(python_impl::PythonCredentialProvider::Builder(Arc::new(
51 PythonObject(func),
52 )))
53 }
54
55 pub(super) fn func_addr(&self) -> usize {
56 match self {
57 Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize,
58 #[cfg(feature = "python")]
59 Self::Python(v) => v.func_addr(),
60 }
61 }
62
63 pub(crate) fn try_into_initialized(self) -> PolarsResult<Option<Self>> {
68 match self {
69 Self::Function(_) => Ok(Some(self)),
70 #[cfg(feature = "python")]
71 Self::Python(v) => Ok(v.try_into_initialized()?.map(Self::Python)),
72 }
73 }
74}
75
/// Credentials returned by a credential provider, tagged by cloud backend.
///
/// Only the variants of enabled cloud features exist; `None` is always present.
pub enum ObjectStoreCredential {
    /// AWS access key / secret / optional session token.
    #[cfg(feature = "aws")]
    Aws(Arc<AwsCredential>),
    /// Azure credential (access key or bearer token).
    #[cfg(feature = "azure")]
    Azure(Arc<AzureCredential>),
    /// GCP bearer-token credential.
    #[cfg(feature = "gcp")]
    Gcp(Arc<GcpCredential>),
    /// Carries no credential payload.
    None,
}
86
87impl ObjectStoreCredential {
88 fn variant_name(&self) -> &'static str {
89 match self {
90 #[cfg(feature = "aws")]
91 Self::Aws(_) => "Aws",
92 #[cfg(feature = "azure")]
93 Self::Azure(_) => "Azure",
94 #[cfg(feature = "gcp")]
95 Self::Gcp(_) => "Gcp",
96 Self::None => "None",
97 }
98 }
99
100 fn panic_type_mismatch(&self, expected: &str) {
101 panic!(
102 "impl error: credential type mismatch: expected {}, got {} instead",
103 expected,
104 self.variant_name()
105 )
106 }
107
108 #[cfg(feature = "aws")]
109 fn unwrap_aws(self) -> Arc<object_store::aws::AwsCredential> {
110 let Self::Aws(v) = self else {
111 self.panic_type_mismatch("aws");
112 unreachable!()
113 };
114 v
115 }
116
117 #[cfg(feature = "azure")]
118 fn unwrap_azure(self) -> Arc<object_store::azure::AzureCredential> {
119 let Self::Azure(v) = self else {
120 self.panic_type_mismatch("azure");
121 unreachable!()
122 };
123 v
124 }
125
126 #[cfg(feature = "gcp")]
127 fn unwrap_gcp(self) -> Arc<object_store::gcp::GcpCredential> {
128 let Self::Gcp(v) = self else {
129 self.panic_type_mismatch("gcp");
130 unreachable!()
131 };
132 v
133 }
134}
135
/// Conversion into the `object_store` credential-provider type of each enabled
/// cloud backend.
///
/// Every method has a default body that panics via `unimplemented!()`, so an
/// implementor overrides only the backends it supports.
pub trait IntoCredentialProvider: Sized {
    /// Converts `self` into an AWS credential provider.
    #[cfg(feature = "aws")]
    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
        unimplemented!()
    }

    /// Converts `self` into an Azure credential provider.
    #[cfg(feature = "azure")]
    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
        unimplemented!()
    }

    /// Converts `self` into a GCP credential provider.
    #[cfg(feature = "gcp")]
    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
        unimplemented!()
    }
}
152
153impl IntoCredentialProvider for PlCredentialProvider {
154 #[cfg(feature = "aws")]
155 fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
156 match self {
157 Self::Function(v) => v.into_aws_provider(),
158 #[cfg(feature = "python")]
159 Self::Python(v) => v.into_aws_provider(),
160 }
161 }
162
163 #[cfg(feature = "azure")]
164 fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
165 match self {
166 Self::Function(v) => v.into_azure_provider(),
167 #[cfg(feature = "python")]
168 Self::Python(v) => v.into_azure_provider(),
169 }
170 }
171
172 #[cfg(feature = "gcp")]
173 fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
174 match self {
175 Self::Function(v) => v.into_gcp_provider(),
176 #[cfg(feature = "python")]
177 Self::Python(v) => v.into_gcp_provider(),
178 }
179 }
180}
181
182type CredentialProviderFunctionImpl = Arc<
183 dyn Fn() -> Pin<
184 Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
185 > + Send
186 + Sync,
187>;
188
189#[derive(Clone)]
191pub struct CredentialProviderFunction(CredentialProviderFunctionImpl);
192
/// Expands to a function that wraps any `Send + Sync + 'static` error into an
/// `object_store::Error::Generic` tagged with the store name `$store`.
///
/// A named inner `fn` is used because, unlike a closure, it can be generic over
/// the error type. It can be passed directly to `map_err`. The inner `fn`
/// cannot capture locals, so `$store` must be a constant expression (in
/// practice, a string literal).
macro_rules! build_to_object_store_err {
    ($store:expr) => {{
        fn wrap_as_object_store_error(
            err: impl std::error::Error + Send + Sync + 'static,
        ) -> object_store::Error {
            let source = Box::new(err);
            object_store::Error::Generic {
                store: $store,
                source,
            }
        }

        wrap_as_object_store_error
    }};
}
207
208impl IntoCredentialProvider for CredentialProviderFunction {
209 #[cfg(feature = "aws")]
210 fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
211 #[derive(Debug)]
212 struct S(
213 CredentialProviderFunction,
214 FetchedCredentialsCache<Arc<object_store::aws::AwsCredential>>,
215 );
216
217 #[async_trait]
218 impl object_store::CredentialProvider for S {
219 type Credential = object_store::aws::AwsCredential;
220
221 async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
222 self.1
223 .get_maybe_update(async {
224 let (creds, expiry) = self.0.0().await?;
225 PolarsResult::Ok((creds.unwrap_aws(), expiry))
226 })
227 .await
228 .map_err(build_to_object_store_err!("credential-provider-aws"))
229 }
230 }
231
232 Arc::new(S(
233 self,
234 FetchedCredentialsCache::new(Arc::new(AwsCredential {
235 key_id: String::new(),
236 secret_key: String::new(),
237 token: None,
238 })),
239 ))
240 }
241
242 #[cfg(feature = "azure")]
243 fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
244 #[derive(Debug)]
245 struct S(
246 CredentialProviderFunction,
247 FetchedCredentialsCache<Arc<object_store::azure::AzureCredential>>,
248 );
249
250 #[async_trait]
251 impl object_store::CredentialProvider for S {
252 type Credential = object_store::azure::AzureCredential;
253
254 async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
255 self.1
256 .get_maybe_update(async {
257 let (creds, expiry) = self.0.0().await?;
258 PolarsResult::Ok((creds.unwrap_azure(), expiry))
259 })
260 .await
261 .map_err(build_to_object_store_err!("credential-provider-azure"))
262 }
263 }
264
265 Arc::new(S(
266 self,
267 FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))),
268 ))
269 }
270
271 #[cfg(feature = "gcp")]
272 fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
273 #[derive(Debug)]
274 struct S(
275 CredentialProviderFunction,
276 FetchedCredentialsCache<Arc<object_store::gcp::GcpCredential>>,
277 );
278
279 #[async_trait]
280 impl object_store::CredentialProvider for S {
281 type Credential = object_store::gcp::GcpCredential;
282
283 async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
284 self.1
285 .get_maybe_update(async {
286 let (creds, expiry) = self.0.0().await?;
287 PolarsResult::Ok((creds.unwrap_gcp(), expiry))
288 })
289 .await
290 .map_err(build_to_object_store_err!("credential-provider-gcp"))
291 }
292 }
293
294 Arc::new(S(
295 self,
296 FetchedCredentialsCache::new(Arc::new(GcpCredential {
297 bearer: String::new(),
298 })),
299 ))
300 }
301}
302
303impl Debug for CredentialProviderFunction {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 write!(
306 f,
307 "credential provider function at 0x{:016x}",
308 self.0.as_ref() as *const _ as *const () as usize
309 )
310 }
311}
312
313impl Eq for CredentialProviderFunction {}
314
315impl PartialEq for CredentialProviderFunction {
316 fn eq(&self, other: &Self) -> bool {
317 Arc::ptr_eq(&self.0, &other.0)
318 }
319}
320
321impl Hash for CredentialProviderFunction {
322 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
323 state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
324 }
325}
326
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PlCredentialProvider {
    /// Only Python-backed providers can be deserialized. Without the `python`
    /// feature this always returns an error.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(not(feature = "python"))]
        {
            Err(<D::Error as serde::de::Error>::custom(
                "cannot deserialize PlCredentialProvider",
            ))
        }
        #[cfg(feature = "python")]
        {
            PythonCredentialProvider::deserialize(_deserializer).map(Self::Python)
        }
    }
}
346
#[cfg(feature = "serde")]
impl serde::Serialize for PlCredentialProvider {
    /// Only Python-backed providers are serializable. Rust closures (and every
    /// variant when the `python` feature is disabled) produce an error.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        #[cfg(feature = "python")]
        if let Self::Python(py_provider) = self {
            return py_provider.serialize(_serializer);
        }

        let msg = format!("cannot serialize {:?}", self);
        Err(<S::Error as serde::ser::Error>::custom(msg))
    }
}
363
364#[derive(Debug)]
366struct FetchedCredentialsCache<C>(tokio::sync::Mutex<(C, u64)>);
367
368impl<C: Clone> FetchedCredentialsCache<C> {
369 fn new(init_creds: C) -> Self {
370 Self(tokio::sync::Mutex::new((init_creds, 0)))
371 }
372
373 async fn get_maybe_update(
374 &self,
375 update_func: impl Future<Output = PolarsResult<(C, u64)>>,
379 ) -> PolarsResult<C> {
380 let verbose = config::verbose();
381
382 fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String {
383 if last_fetched_expiry == u64::MAX {
384 "expiry = (never expires)".into()
385 } else {
386 format!(
387 "expiry = {} (in {} seconds)",
388 last_fetched_expiry,
389 last_fetched_expiry.saturating_sub(now)
390 )
391 }
392 }
393
394 let mut inner = self.0.lock().await;
395 let (last_fetched_credentials, last_fetched_expiry) = &mut *inner;
396
397 let current_time = SystemTime::now()
398 .duration_since(UNIX_EPOCH)
399 .unwrap()
400 .as_secs();
401
402 const REQUEST_TIME_BUFFER: u64 = 7;
405
406 if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER {
407 if verbose {
408 eprintln!(
409 "[FetchedCredentialsCache]: Call update_func: current_time = {}\
410 , last_fetched_expiry = {}",
411 current_time, *last_fetched_expiry
412 )
413 }
414 let (credentials, expiry) = update_func.await?;
415
416 *last_fetched_credentials = credentials;
417 *last_fetched_expiry = expiry;
418
419 if expiry < current_time && expiry != 0 {
420 polars_bail!(
421 ComputeError:
422 "credential expiry time {} is older than system time {} \
423 by {} seconds",
424 expiry,
425 current_time,
426 current_time - expiry
427 )
428 }
429
430 if verbose {
431 eprintln!(
432 "[FetchedCredentialsCache]: Finish update_func: new {}",
433 expiry_msg(
434 *last_fetched_expiry,
435 SystemTime::now()
436 .duration_since(UNIX_EPOCH)
437 .unwrap()
438 .as_secs()
439 )
440 )
441 }
442 } else if verbose {
443 let now = SystemTime::now()
444 .duration_since(UNIX_EPOCH)
445 .unwrap()
446 .as_secs();
447 eprintln!(
448 "[FetchedCredentialsCache]: Using cached credentials: \
449 current_time = {}, {}",
450 now,
451 expiry_msg(*last_fetched_expiry, now)
452 )
453 }
454
455 Ok(last_fetched_credentials.clone())
456 }
457}
458
/// Credential providers backed by Python objects.
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsError, PolarsResult, to_compute_err};
    use polars_utils::python_function::PythonObject;
    use pyo3::Python;
    use pyo3::exceptions::PyValueError;
    use pyo3::pybacked::PyBackedStr;
    use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};

    use super::IntoCredentialProvider;

    /// Credential provider implemented in Python.
    ///
    /// A `Builder` must be turned into a `Provider` with `try_into_initialized`
    /// before it can be converted into an `object_store` credential provider.
    /// Both variants (de)serialize through
    /// `PythonObject::{serialize,deserialize}_with_pyversion`.
    #[derive(Clone, Debug)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub enum PythonCredentialProvider {
        /// Python object exposing a `build_credential_provider()` method.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        Builder(Arc<PythonObject>),
        /// Python callable invoked with no arguments to fetch credentials.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        Provider(Arc<PythonObject>),
    }

    impl PythonCredentialProvider {
        /// Resolves a `Builder` into a `Provider` by calling its
        /// `build_credential_provider()` method under the GIL.
        ///
        /// Returns `Ok(None)` if that method returned Python `None`. A
        /// `Provider` is returned unchanged.
        ///
        /// # Errors
        /// Any Python exception raised while looking up or calling the method,
        /// converted with `to_compute_err`.
        pub(super) fn try_into_initialized(self) -> PolarsResult<Option<Self>> {
            match self {
                Self::Builder(py_object) => {
                    let opt_initialized_py_object = Python::with_gil(|py| {
                        let build_fn = py_object.getattr(py, "build_credential_provider")?;

                        let v = build_fn.call0(py)?;
                        // Map Python `None` to Rust `None`.
                        let v = (!v.is_none(py)).then_some(v);

                        pyo3::PyResult::Ok(v)
                    })
                    .map_err(to_compute_err)?;

                    Ok(opt_initialized_py_object
                        .map(PythonObject)
                        .map(Arc::new)
                        .map(Self::Provider))
                },
                Self::Provider(_) => {
                    // Already initialized.
                    Ok(Some(self))
                },
            }
        }

        /// Returns the Python callable of an initialized provider.
        ///
        /// # Panics
        /// If `self` is still a `Builder`, meaning `try_into_initialized` was not
        /// called before converting into an `object_store` provider (an internal bug).
        fn unwrap_as_provider(self) -> Arc<PythonObject> {
            match self {
                Self::Builder(_) => panic!(
                    "impl error: PythonCredentialProvider::Builder was used as a provider \
                    without first being initialized via try_into_initialized"
                ),
                Self::Provider(v) => v,
            }
        }

        /// Address of the wrapped Python object, used as this provider's identity.
        pub(super) fn func_addr(&self) -> usize {
            (match self {
                Self::Builder(v) => Arc::as_ptr(v),
                Self::Provider(v) => Arc::as_ptr(v),
            }) as *const () as usize
        }
    }

    /// Each conversion wraps the Python callable in a `CredentialProviderFunction`.
    /// The callable must return `(dict, expiry)`, where `expiry` is an optional
    /// integer and `None` maps to `u64::MAX`. The dict keys are backend-specific;
    /// unknown keys are rejected.
    ///
    /// NOTE(review): `Python::with_gil` runs the Python call synchronously inside
    /// the returned future, blocking the executing thread until it completes.
    impl IntoCredentialProvider for PythonCredentialProvider {
        /// Accepted keys: `aws_access_key_id`, `aws_secret_access_key` (both
        /// required and non-empty) and optional `aws_session_token`.
        #[cfg(feature = "aws")]
        fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
            use polars_error::{PolarsResult, to_compute_err};

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = object_store::aws::AwsCredential {
                        key_id: String::new(),
                        secret_key: String::new(),
                        token: None,
                    };

                    let expiry = Python::with_gil(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (k, v) in storage_options.iter() {
                            let k = k.extract::<PyBackedStr>()?;
                            let v = v.extract::<Option<String>>()?;

                            match k.as_ref() {
                                "aws_access_key_id" => {
                                    credentials.key_id = v.ok_or_else(|| {
                                        PyValueError::new_err("aws_access_key_id was None")
                                    })?;
                                },
                                "aws_secret_access_key" => {
                                    credentials.secret_key = v.ok_or_else(|| {
                                        PyValueError::new_err("aws_secret_access_key was None")
                                    })?
                                },
                                "aws_session_token" => credentials.token = v,
                                v => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for aws: {}, \
                                    valid configuration keys are: \
                                    {}, {}, {}",
                                        v,
                                        "aws_access_key_id",
                                        "aws_secret_access_key",
                                        "aws_session_token"
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })
                    .map_err(to_compute_err)?;

                    if credentials.key_id.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_access_key_id was empty or not given".into(),
                        ));
                    }

                    if credentials.secret_key.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_secret_access_key was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry))
                })
            }))
            .into_aws_provider()
        }

        /// Accepted keys: `account_key` or `bearer_token`. One of them must be
        /// present; if both are given, the one iterated last wins.
        #[cfg(feature = "azure")]
        fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
            use object_store::azure::AzureAccessKey;
            use polars_error::{PolarsResult, to_compute_err};

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = None;

                    static VALID_KEYS_MSG: &str =
                        "valid configuration keys are: account_key, bearer_token";

                    let expiry = Python::with_gil(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (k, v) in storage_options.iter() {
                            let k = k.extract::<PyBackedStr>()?;
                            let v = v.extract::<String>()?;

                            match k.as_ref() {
                                "account_key" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::AccessKey(
                                            AzureAccessKey::try_new(v.as_str()).map_err(|e| {
                                                PyValueError::new_err(e.to_string())
                                            })?,
                                        ))
                                },
                                "bearer_token" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::BearerToken(v))
                                },
                                v => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for azure: {}, {}",
                                        v, VALID_KEYS_MSG
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })
                    .map_err(to_compute_err)?;

                    let Some(credentials) = credentials else {
                        return Err(PolarsError::ComputeError(
                            format!(
                                "did not find a valid configuration key for azure, {}",
                                VALID_KEYS_MSG
                            )
                            .into(),
                        ));
                    };

                    PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry))
                })
            }))
            .into_azure_provider()
        }

        /// Accepted key: `bearer_token` (required, non-empty).
        #[cfg(feature = "gcp")]
        fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
            use polars_error::{PolarsResult, to_compute_err};

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = object_store::gcp::GcpCredential {
                        bearer: String::new(),
                    };

                    let expiry = Python::with_gil(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (k, v) in storage_options.iter() {
                            let k = k.extract::<PyBackedStr>()?;
                            let v = v.extract::<String>()?;

                            match k.as_ref() {
                                "bearer_token" => credentials.bearer = v,
                                v => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for gcp: {}, \
                                    valid configuration keys are: {}",
                                        v, "bearer_token",
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })
                    .map_err(to_compute_err)?;

                    if credentials.bearer.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "bearer was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry))
                })
            }))
            .into_gcp_provider()
        }
    }

    impl Eq for PythonCredentialProvider {}

    /// Identity comparison: equal only if both wrap the same Python object allocation.
    impl PartialEq for PythonCredentialProvider {
        fn eq(&self, other: &Self) -> bool {
            self.func_addr() == other.func_addr()
        }
    }

    /// Hashes the wrapped object's address, consistent with `PartialEq`.
    impl Hash for PythonCredentialProvider {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            state.write_usize(self.func_addr())
        }
    }
}
765
#[cfg(test)]
mod tests {
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        use super::*;

        // A Rust-closure provider has no serialized form, so serialization must fail.
        let closure_provider = PlCredentialProvider::from_func(|| {
            Box::pin(core::future::ready(PolarsResult::Ok((
                ObjectStoreCredential::None,
                0,
            ))))
        });
        assert!(serde_json::to_string(&Some(closure_provider)).is_err());

        // `None` serializes successfully and round-trips back to `None`.
        let none_json = serde_json::to_string(&Option::<PlCredentialProvider>::None)
            .expect("serializing None should succeed");
        let decoded = serde_json::from_str::<Option<PlCredentialProvider>>(&none_json);
        assert!(matches!(decoded, Ok(None)));
    }
}