polars_io/cloud/credential_provider.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use async_trait::async_trait;
9#[cfg(feature = "aws")]
10pub use object_store::aws::AwsCredential;
11#[cfg(feature = "azure")]
12pub use object_store::azure::AzureCredential;
13#[cfg(feature = "gcp")]
14pub use object_store::gcp::GcpCredential;
15use polars_core::config;
16use polars_error::{PolarsResult, polars_bail};
17use polars_utils::pl_str::PlSmallStr;
18#[cfg(feature = "python")]
19use polars_utils::python_function::PythonObject;
20#[cfg(feature = "python")]
21use python_impl::PythonCredentialProvider;
22
23#[derive(Clone, Debug, PartialEq, Hash, Eq)]
24pub enum PlCredentialProvider {
25    /// Prefer using [`PlCredentialProvider::from_func`] instead of constructing this directly
26    Function(CredentialProviderFunction),
27    #[cfg(feature = "python")]
28    Python(PythonCredentialProvider),
29}
30
31impl PlCredentialProvider {
32    /// Accepts a function that returns (credential, expiry time as seconds since UNIX_EPOCH)
33    ///
34    /// This functionality is unstable.
35    pub fn from_func(
36        // Internal notes
37        // * This function is exposed as the Rust API for `PlCredentialProvider`
38        func: impl Fn() -> Pin<
39            Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
40        > + Send
41        + Sync
42        + 'static,
43    ) -> Self {
44        Self::Function(CredentialProviderFunction(Arc::new(func)))
45    }
46
47    /// Intended to be called with an internal `CredentialProviderBuilder` from
48    /// py-polars.
49    #[cfg(feature = "python")]
50    pub fn from_python_builder(func: pyo3::PyObject) -> Self {
51        Self::Python(python_impl::PythonCredentialProvider::Builder(Arc::new(
52            PythonObject(func),
53        )))
54    }
55
56    pub(super) fn func_addr(&self) -> usize {
57        match self {
58            Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize,
59            #[cfg(feature = "python")]
60            Self::Python(v) => v.func_addr(),
61        }
62    }
63
64    /// Python passes a `CredentialProviderBuilder`, this calls the builder to build the final
65    /// credential provider.
66    ///
67    /// This returns `Option` as the auto-initialization case is fallible and falls back to None.
68    pub(crate) fn try_into_initialized(
69        self,
70        clear_cached_credentials: bool,
71    ) -> PolarsResult<Option<Self>> {
72        match self {
73            Self::Function(_) => Ok(Some(self)),
74            #[cfg(feature = "python")]
75            Self::Python(v) => Ok(v
76                .try_into_initialized(clear_cached_credentials)?
77                .map(Self::Python)),
78        }
79    }
80}
81
/// A credential for one of the supported `object_store` backends, as produced by a
/// credential provider function. Each cloud variant exists only when its cargo feature is on.
pub enum ObjectStoreCredential {
    /// AWS credential (`aws` feature).
    #[cfg(feature = "aws")]
    Aws(Arc<object_store::aws::AwsCredential>),
    /// Azure credential (`azure` feature).
    #[cfg(feature = "azure")]
    Azure(Arc<object_store::azure::AzureCredential>),
    /// GCP credential (`gcp` feature).
    #[cfg(feature = "gcp")]
    Gcp(Arc<object_store::gcp::GcpCredential>),
    /// For testing purposes
    None,
}
92
93impl ObjectStoreCredential {
94    fn variant_name(&self) -> &'static str {
95        match self {
96            #[cfg(feature = "aws")]
97            Self::Aws(_) => "Aws",
98            #[cfg(feature = "azure")]
99            Self::Azure(_) => "Azure",
100            #[cfg(feature = "gcp")]
101            Self::Gcp(_) => "Gcp",
102            Self::None => "None",
103        }
104    }
105
106    fn panic_type_mismatch(&self, expected: &str) {
107        panic!(
108            "impl error: credential type mismatch: expected {}, got {} instead",
109            expected,
110            self.variant_name()
111        )
112    }
113
114    #[cfg(feature = "aws")]
115    fn unwrap_aws(self) -> Arc<object_store::aws::AwsCredential> {
116        let Self::Aws(v) = self else {
117            self.panic_type_mismatch("aws");
118            unreachable!()
119        };
120        v
121    }
122
123    #[cfg(feature = "azure")]
124    fn unwrap_azure(self) -> Arc<object_store::azure::AzureCredential> {
125        let Self::Azure(v) = self else {
126            self.panic_type_mismatch("azure");
127            unreachable!()
128        };
129        v
130    }
131
132    #[cfg(feature = "gcp")]
133    fn unwrap_gcp(self) -> Arc<object_store::gcp::GcpCredential> {
134        let Self::Gcp(v) = self else {
135            self.panic_type_mismatch("gcp");
136            unreachable!()
137        };
138        v
139    }
140}
141
142pub trait IntoCredentialProvider: Sized {
143    #[cfg(feature = "aws")]
144    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
145        unimplemented!()
146    }
147
148    #[cfg(feature = "azure")]
149    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
150        unimplemented!()
151    }
152
153    #[cfg(feature = "gcp")]
154    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
155        unimplemented!()
156    }
157
158    /// Note, technically shouldn't be under the `IntoCredentialProvider` trait, but it's here
159    /// for convenience.
160    fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>>;
161}
162
163impl IntoCredentialProvider for PlCredentialProvider {
164    #[cfg(feature = "aws")]
165    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
166        match self {
167            Self::Function(v) => v.into_aws_provider(),
168            #[cfg(feature = "python")]
169            Self::Python(v) => v.into_aws_provider(),
170        }
171    }
172
173    #[cfg(feature = "azure")]
174    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
175        match self {
176            Self::Function(v) => v.into_azure_provider(),
177            #[cfg(feature = "python")]
178            Self::Python(v) => v.into_azure_provider(),
179        }
180    }
181
182    #[cfg(feature = "gcp")]
183    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
184        match self {
185            Self::Function(v) => v.into_gcp_provider(),
186            #[cfg(feature = "python")]
187            Self::Python(v) => v.into_gcp_provider(),
188        }
189    }
190
191    fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>> {
192        match self {
193            Self::Function(v) => v.storage_update_options(),
194            #[cfg(feature = "python")]
195            Self::Python(v) => v.storage_update_options(),
196        }
197    }
198}
199
200type CredentialProviderFunctionImpl = Arc<
201    dyn Fn() -> Pin<
202            Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
203        > + Send
204        + Sync,
205>;
206
207/// Wrapper that implements [`IntoCredentialProvider`], [`Debug`], [`PartialEq`], [`Hash`] etc.
208#[derive(Clone)]
209pub struct CredentialProviderFunction(CredentialProviderFunctionImpl);
210
/// Expands to a function item that wraps any error into `object_store::Error::Generic`,
/// tagging it with the store name `$s`.
///
/// Because the expansion is a `fn` item (it cannot capture locals), `$s` must be an
/// expression valid in item context, such as a string literal.
macro_rules! build_to_object_store_err {
    ($s:expr) => {{
        fn wrap_as_object_store_error(
            err: impl std::error::Error + Send + Sync + 'static,
        ) -> object_store::Error {
            let source: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(err);
            object_store::Error::Generic { store: $s, source }
        }

        wrap_as_object_store_error
    }};
}
225
226impl IntoCredentialProvider for CredentialProviderFunction {
227    #[cfg(feature = "aws")]
228    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
229        #[derive(Debug)]
230        struct S(
231            CredentialProviderFunction,
232            FetchedCredentialsCache<Arc<object_store::aws::AwsCredential>>,
233        );
234
235        #[async_trait]
236        impl object_store::CredentialProvider for S {
237            type Credential = object_store::aws::AwsCredential;
238
239            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
240                self.1
241                    .get_maybe_update(async {
242                        let (creds, expiry) = self.0.0().await?;
243                        PolarsResult::Ok((creds.unwrap_aws(), expiry))
244                    })
245                    .await
246                    .map_err(build_to_object_store_err!("credential-provider-aws"))
247            }
248        }
249
250        Arc::new(S(
251            self,
252            FetchedCredentialsCache::new(Arc::new(AwsCredential {
253                key_id: String::new(),
254                secret_key: String::new(),
255                token: None,
256            })),
257        ))
258    }
259
260    #[cfg(feature = "azure")]
261    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
262        #[derive(Debug)]
263        struct S(
264            CredentialProviderFunction,
265            FetchedCredentialsCache<Arc<object_store::azure::AzureCredential>>,
266        );
267
268        #[async_trait]
269        impl object_store::CredentialProvider for S {
270            type Credential = object_store::azure::AzureCredential;
271
272            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
273                self.1
274                    .get_maybe_update(async {
275                        let (creds, expiry) = self.0.0().await?;
276                        PolarsResult::Ok((creds.unwrap_azure(), expiry))
277                    })
278                    .await
279                    .map_err(build_to_object_store_err!("credential-provider-azure"))
280            }
281        }
282
283        Arc::new(S(
284            self,
285            FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))),
286        ))
287    }
288
289    #[cfg(feature = "gcp")]
290    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
291        #[derive(Debug)]
292        struct S(
293            CredentialProviderFunction,
294            FetchedCredentialsCache<Arc<object_store::gcp::GcpCredential>>,
295        );
296
297        #[async_trait]
298        impl object_store::CredentialProvider for S {
299            type Credential = object_store::gcp::GcpCredential;
300
301            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
302                self.1
303                    .get_maybe_update(async {
304                        let (creds, expiry) = self.0.0().await?;
305                        PolarsResult::Ok((creds.unwrap_gcp(), expiry))
306                    })
307                    .await
308                    .map_err(build_to_object_store_err!("credential-provider-gcp"))
309            }
310        }
311
312        Arc::new(S(
313            self,
314            FetchedCredentialsCache::new(Arc::new(GcpCredential {
315                bearer: String::new(),
316            })),
317        ))
318    }
319
320    fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>> {
321        Ok(vec![])
322    }
323}
324
325impl Debug for CredentialProviderFunction {
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        write!(
328            f,
329            "credential provider function at 0x{:016x}",
330            self.0.as_ref() as *const _ as *const () as usize
331        )
332    }
333}
334
335impl Eq for CredentialProviderFunction {}
336
337impl PartialEq for CredentialProviderFunction {
338    fn eq(&self, other: &Self) -> bool {
339        Arc::ptr_eq(&self.0, &other.0)
340    }
341}
342
343impl Hash for CredentialProviderFunction {
344    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
345        state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
346    }
347}
348
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PlCredentialProvider {
    /// Only the Python-backed variant has a serialized form. Without the `python` feature,
    /// deserialization always fails.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            PythonCredentialProvider::deserialize(_deserializer).map(Self::Python)
        }
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;
            Err(D::Error::custom("cannot deserialize PlCredentialProvider"))
        }
    }
}
368
#[cfg(feature = "serde")]
impl serde::Serialize for PlCredentialProvider {
    /// Serializes the Python-backed variant. Rust closures cannot be serialized and produce
    /// an error naming the provider.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::Error;

        #[cfg(feature = "python")]
        if let Self::Python(py_provider) = self {
            return py_provider.serialize(_serializer);
        }

        let message = format!("cannot serialize {self:?}");
        Err(S::Error::custom(message))
    }
}
385
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for PlCredentialProvider {
    fn schema_name() -> String {
        String::from("PlCredentialProvider")
    }

    fn schema_id() -> std::borrow::Cow<'static, str> {
        concat!(module_path!(), "::PlCredentialProvider").into()
    }

    /// The provider is described in the schema as an opaque byte array (`Vec<u8>`).
    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
400
401/// Avoids calling the credential provider function if we have not yet passed the expiry time.
402#[derive(Debug)]
403struct FetchedCredentialsCache<C>(tokio::sync::Mutex<(C, u64, bool)>);
404
405impl<C: Clone> FetchedCredentialsCache<C> {
406    fn new(init_creds: C) -> Self {
407        Self(tokio::sync::Mutex::new((init_creds, 0, true)))
408    }
409
410    async fn get_maybe_update(
411        &self,
412        // Taking an `impl Future` here allows us to potentially avoid a `Box::pin` allocation from
413        // a `Fn() -> Pin<Box<dyn Future>>` by having it wrapped in an `async { f() }` block. We
414        // will not poll that block if the credentials have not yet expired.
415        update_func: impl Future<Output = PolarsResult<(C, u64)>>,
416    ) -> PolarsResult<C> {
417        let verbose = config::verbose();
418
419        fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String {
420            if last_fetched_expiry == u64::MAX {
421                "expiry = (never expires)".into()
422            } else {
423                format!(
424                    "expiry = {} (in {} seconds)",
425                    last_fetched_expiry,
426                    last_fetched_expiry.saturating_sub(now)
427                )
428            }
429        }
430
431        let mut inner = self.0.lock().await;
432        let (last_fetched_credentials, last_fetched_expiry, log_use_cached) = &mut *inner;
433
434        let current_time = SystemTime::now()
435            .duration_since(UNIX_EPOCH)
436            .unwrap()
437            .as_secs();
438
439        if *last_fetched_expiry <= current_time {
440            if verbose {
441                eprintln!(
442                    "[FetchedCredentialsCache]: \
443                    Call update_func: current_time = {}, \
444                    last_fetched_expiry = {}",
445                    current_time, *last_fetched_expiry
446                )
447            }
448
449            let (credentials, expiry) = update_func.await?;
450
451            *last_fetched_credentials = credentials;
452            *last_fetched_expiry = expiry;
453            *log_use_cached = true;
454
455            if expiry < current_time && expiry != 0 {
456                polars_bail!(
457                    ComputeError:
458                    "credential expiry time {} is older than system time {} \
459                     by {} seconds",
460                    expiry,
461                    current_time,
462                    current_time - expiry
463                )
464            }
465
466            if verbose {
467                eprintln!(
468                    "[FetchedCredentialsCache]: Finish update_func: new {}",
469                    expiry_msg(
470                        *last_fetched_expiry,
471                        SystemTime::now()
472                            .duration_since(UNIX_EPOCH)
473                            .unwrap()
474                            .as_secs()
475                    )
476                )
477            }
478        } else if verbose && *log_use_cached {
479            *log_use_cached = false;
480            let now = SystemTime::now()
481                .duration_since(UNIX_EPOCH)
482                .unwrap()
483                .as_secs();
484            eprintln!(
485                "[FetchedCredentialsCache]: Using cached credentials: \
486                current_time = {}, {}",
487                now,
488                expiry_msg(*last_fetched_expiry, now)
489            )
490        }
491
492        Ok(last_fetched_credentials.clone())
493    }
494}
495
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsError, PolarsResult};
    use polars_utils::pl_str::PlSmallStr;
    use polars_utils::python_function::PythonObject;
    use pyo3::exceptions::PyValueError;
    use pyo3::pybacked::PyBackedStr;
    use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
    use pyo3::{Python, intern};

    use super::IntoCredentialProvider;

    /// Panic message for using a `Builder` where an initialized `Provider` is required.
    const UNINITIALIZED_PROVIDER_MSG: &str = "impl error: PythonCredentialProvider is an \
        uninitialized Builder; try_into_initialized must be called before use";

    /// Credential provider backed by a Python object.
    ///
    /// Starts out as a `Builder` and becomes a `Provider` after
    /// [`PythonCredentialProvider::try_into_initialized`]. Only a `Provider` can be converted
    /// into an `object_store` credential provider.
    #[derive(Clone, Debug)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub enum PythonCredentialProvider {
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        /// Indicates `py_object` is a `CredentialProviderBuilder`.
        Builder(Arc<PythonObject>),
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        /// Indicates `py_object` is an instantiated credential provider
        Provider(Arc<PythonObject>),
    }

    impl PythonCredentialProvider {
        /// Performs initialization if necessary.
        ///
        /// For a `Builder`, calls its Python `build_credential_provider(clear_cached_credentials)`
        /// method. A Python `None` result maps to `Ok(None)`. A `Provider` is returned unchanged.
        ///
        /// This exists as a separate step that must be called beforehand. This approach is easier
        /// as the alternative is to refactor the `IntoCredentialProvider` trait to return
        /// `PolarsResult<Option<T>>` for every single function.
        ///
        /// # Errors
        /// Returns an error if the Python attribute lookup or call raises.
        pub(super) fn try_into_initialized(
            self,
            clear_cached_credentials: bool,
        ) -> PolarsResult<Option<Self>> {
            match self {
                Self::Builder(py_object) => {
                    let opt_initialized_py_object = Python::with_gil(|py| {
                        let build_fn =
                            py_object.getattr(py, intern!(py, "build_credential_provider"))?;

                        let v = build_fn.call1(py, (clear_cached_credentials,))?;
                        let v = (!v.is_none(py)).then_some(v);

                        pyo3::PyResult::Ok(v)
                    })?;

                    Ok(opt_initialized_py_object
                        .map(PythonObject)
                        .map(Arc::new)
                        .map(Self::Provider))
                },
                Self::Provider(_) => {
                    // Note: We don't expect to hit here.
                    Ok(Some(self))
                },
            }
        }

        /// Returns the initialized provider object, consuming `self`.
        ///
        /// # Panics
        /// Panics if `self` is still a `Builder` (`try_into_initialized` was not called).
        fn unwrap_as_provider(self) -> Arc<PythonObject> {
            match self {
                Self::Builder(_) => panic!("{}", UNINITIALIZED_PROVIDER_MSG),
                Self::Provider(v) => v,
            }
        }

        /// Borrows the initialized provider object.
        ///
        /// # Panics
        /// Panics if `self` is still a `Builder` (`try_into_initialized` was not called).
        pub(crate) fn unwrap_as_provider_ref(&self) -> &Arc<PythonObject> {
            match self {
                Self::Builder(_) => panic!("{}", UNINITIALIZED_PROVIDER_MSG),
                Self::Provider(v) => v,
            }
        }

        /// Address of the shared `PythonObject` allocation, used as an identity key.
        pub(super) fn func_addr(&self) -> usize {
            (match self {
                Self::Builder(v) => Arc::as_ptr(v),
                Self::Provider(v) => Arc::as_ptr(v),
            }) as *const () as usize
        }
    }

    impl IntoCredentialProvider for PythonCredentialProvider {
        /// Wraps the Python provider in a Rust closure that, on each refresh, calls it and
        /// parses the returned `(storage_options: dict, expiry: int | None)` into an AWS
        /// credential. A `None` expiry means the credential never expires.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        #[cfg(feature = "aws")]
        fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
            use super::{CredentialProviderFunction, ObjectStoreCredential};

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = object_store::aws::AwsCredential {
                        key_id: String::new(),
                        secret_key: String::new(),
                        token: None,
                    };

                    let expiry = Python::with_gil(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (key, value) in storage_options.iter() {
                            let key = key.extract::<PyBackedStr>()?;
                            let value = value.extract::<Option<String>>()?;

                            match key.as_ref() {
                                "aws_access_key_id" => {
                                    credentials.key_id = value.ok_or_else(|| {
                                        PyValueError::new_err("aws_access_key_id was None")
                                    })?;
                                },
                                "aws_secret_access_key" => {
                                    credentials.secret_key = value.ok_or_else(|| {
                                        PyValueError::new_err("aws_secret_access_key was None")
                                    })?
                                },
                                "aws_session_token" => credentials.token = value,
                                unknown_key => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for aws: {}, \
                                    valid configuration keys are: \
                                    {}, {}, {}",
                                        unknown_key,
                                        "aws_access_key_id",
                                        "aws_secret_access_key",
                                        "aws_session_token"
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    if credentials.key_id.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_access_key_id was empty or not given".into(),
                        ));
                    }

                    if credentials.secret_key.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_secret_access_key was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry))
                })
            }))
            .into_aws_provider()
        }

        /// Azure counterpart of `into_aws_provider`. Accepts either `account_key` or
        /// `bearer_token`; if several keys are given, the last one read wins.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        #[cfg(feature = "azure")]
        fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
            use object_store::azure::AzureAccessKey;

            use super::{CredentialProviderFunction, ObjectStoreCredential};

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = None;

                    static VALID_KEYS_MSG: &str =
                        "valid configuration keys are: account_key, bearer_token";

                    let expiry = Python::with_gil(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (key, value) in storage_options.iter() {
                            let key = key.extract::<PyBackedStr>()?;
                            let value = value.extract::<String>()?;

                            match key.as_ref() {
                                "account_key" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::AccessKey(
                                            AzureAccessKey::try_new(value.as_str()).map_err(
                                                |e| PyValueError::new_err(e.to_string()),
                                            )?,
                                        ))
                                },
                                "bearer_token" => {
                                    credentials = Some(
                                        object_store::azure::AzureCredential::BearerToken(value),
                                    )
                                },
                                unknown_key => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for azure: {unknown_key}, {VALID_KEYS_MSG}"
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    let Some(credentials) = credentials else {
                        return Err(PolarsError::ComputeError(
                            format!(
                                "did not find a valid configuration key for azure, {VALID_KEYS_MSG}"
                            )
                            .into(),
                        ));
                    };

                    PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry))
                })
            }))
            .into_azure_provider()
        }

        /// GCP counterpart of `into_aws_provider`. Accepts only `bearer_token`.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        #[cfg(feature = "gcp")]
        fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
            use super::{CredentialProviderFunction, ObjectStoreCredential};

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = object_store::gcp::GcpCredential {
                        bearer: String::new(),
                    };

                    let expiry = Python::with_gil(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (key, value) in storage_options.iter() {
                            let key = key.extract::<PyBackedStr>()?;
                            let value = value.extract::<String>()?;

                            match key.as_ref() {
                                "bearer_token" => credentials.bearer = value,
                                unknown_key => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for gcp: {}, \
                                    valid configuration keys are: {}",
                                        unknown_key, "bearer_token",
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    if credentials.bearer.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "bearer was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry))
                })
            }))
            .into_gcp_provider()
        }

        /// Calls the Python provider's optional `_storage_update_options()` method and
        /// returns its `dict[str, str]` result as key/value pairs. If the attribute lookup
        /// fails, this deliberately returns an empty list.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>> {
            let py_object = self.unwrap_as_provider_ref();

            Python::with_gil(|py| {
                py_object
                    .getattr(py, "_storage_update_options")
                    .map_or(Ok(vec![]), |f| {
                        let v = f.call0(py)?.extract::<pyo3::Bound<'_, PyDict>>(py)?;

                        let mut out = Vec::with_capacity(v.len());

                        for dict_item in v.call_method0("items")?.try_iter()? {
                            let (key, value) =
                                dict_item?.extract::<(PyBackedStr, PyBackedStr)>()?;

                            out.push(((&*key).into(), (&*value).into()))
                        }

                        Ok(out)
                    })
            })
        }
    }

    // Hash/Eq compare only the address of the shared `PythonObject`, not the variant. The same
    // `Arc<PythonObject>` is not expected to be held as both a `Builder` and a `Provider`.

    impl Eq for PythonCredentialProvider {}

    impl PartialEq for PythonCredentialProvider {
        fn eq(&self, other: &Self) -> bool {
            self.func_addr() == other.func_addr()
        }
    }

    impl Hash for PythonCredentialProvider {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            // Hashes the same address that `PartialEq` compares, so equal values hash equally.
            // The address stays stable because the object is held behind an `Arc`.
            state.write_usize(self.func_addr())
        }
    }
}
833
#[cfg(test)]
mod tests {
    /// Rust-closure providers must refuse to serialize, while `None` round-trips through JSON.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        use super::*;

        let func_provider = PlCredentialProvider::from_func(|| {
            Box::pin(core::future::ready(PolarsResult::Ok((
                ObjectStoreCredential::None,
                0,
            ))))
        });
        assert!(serde_json::to_string(&Some(func_provider)).is_err());

        let none_json = serde_json::to_string(&Option::<PlCredentialProvider>::None);
        assert!(none_json.is_ok());

        assert!(matches!(
            serde_json::from_str::<Option<PlCredentialProvider>>(none_json.unwrap().as_str()),
            Ok(None)
        ));
    }
}