polars_utils/async_utils/error_capture.rs

1use std::any::Any;
2use std::panic::AssertUnwindSafe;
3
4use futures::FutureExt;
5
6/// Utility to capture errors and propagate them to an associated [`ErrorHandle`].
7pub struct ErrorCapture<ErrorT> {
8    tx: tokio::sync::mpsc::Sender<ErrorMessage<ErrorT>>,
9}
10
11impl<ErrorT> Clone for ErrorCapture<ErrorT> {
12    fn clone(&self) -> Self {
13        Self {
14            tx: self.tx.clone(),
15        }
16    }
17}
18
19impl<ErrorT> ErrorCapture<ErrorT> {
20    pub fn new() -> (Self, ErrorHandle<ErrorT>) {
21        let (tx, rx) = tokio::sync::mpsc::channel(1);
22        (Self { tx }, ErrorHandle { rx })
23    }
24
25    /// Wraps a future such that its error result is sent to the associated [`ErrorHandle`].
26    pub async fn wrap_future<F, O>(self, fut: F)
27    where
28        F: Future<Output = Result<O, ErrorT>>,
29    {
30        let err: Result<(), tokio::sync::mpsc::error::TrySendError<ErrorMessage<ErrorT>>> =
31            match AssertUnwindSafe(fut).catch_unwind().await {
32                Ok(Ok(_)) => return,
33                Ok(Err(err)) => self.tx.try_send(ErrorMessage::Error(err)),
34                Err(panic) => self.tx.try_send(ErrorMessage::Panic(panic)),
35            };
36        drop(err);
37    }
38}
39
/// A failure captured by `ErrorCapture` and delivered to an `ErrorHandle`.
enum ErrorMessage<ErrorT> {
    /// The wrapped future resolved to `Err`.
    Error(ErrorT),
    /// The wrapped future panicked. This holds the payload produced by
    /// `catch_unwind`, so it can later be handed to `resume_unwind`.
    Panic(Box<dyn Any + Send>),
}
44
45/// Handle to await the completion of multiple tasks. Propagates error results
46/// and resumes unwinds when joined.
47pub struct ErrorHandle<ErrorT> {
48    rx: tokio::sync::mpsc::Receiver<ErrorMessage<ErrorT>>,
49}
50
51impl<ErrorT> ErrorHandle<ErrorT> {
52    pub fn has_errored(&self) -> bool {
53        !self.rx.is_empty()
54    }
55
56    /// Block until either an error is received, or all [`ErrorCapture`]s associated with this
57    /// handle are dropped (i.e. successful completion of all wrapped futures).
58    ///
59    /// # Panics
60    /// If a panic is received, this will resume unwinding.
61    pub async fn join(self) -> Result<(), ErrorT> {
62        let ErrorHandle { mut rx } = self;
63
64        match rx.recv().await {
65            None => Ok(()),
66            Some(ErrorMessage::Error(e)) => Err(e),
67            Some(ErrorMessage::Panic(panic)) => std::panic::resume_unwind(panic),
68        }
69    }
70}