polars_time/
round.rs

1use arrow::legacy::time_zone::Tz;
2use arrow::temporal_conversions::MILLISECONDS_IN_DAY;
3use polars_core::prelude::arity::broadcast_try_binary_elementwise;
4use polars_core::prelude::*;
5use polars_utils::cache::LruCache;
6
7use crate::prelude::*;
8use crate::truncate::fast_truncate;
9
10#[inline(always)]
11fn fast_round(t: i64, every: i64) -> i64 {
12    fast_truncate(t + every / 2, every)
13}
14
15pub trait PolarsRound {
16    fn round(&self, every: &StringChunked, tz: Option<&Tz>) -> PolarsResult<Self>
17    where
18        Self: Sized;
19}
20
21impl PolarsRound for DatetimeChunked {
22    fn round(&self, every: &StringChunked, tz: Option<&Tz>) -> PolarsResult<Self> {
23        let time_zone = self.time_zone();
24        let offset = Duration::new(0);
25
26        // Let's check if we can use a fastpath...
27        if every.len() == 1 {
28            if let Some(every) = every.get(0) {
29                let every_parsed = Duration::try_parse(every)?;
30                if every_parsed.negative {
31                    polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
32                }
33                if (time_zone.is_none() || time_zone == &Some(TimeZone::UTC))
34                    && (every_parsed.months() == 0 && every_parsed.weeks() == 0)
35                {
36                    // ... yes we can! Weeks, months, and time zones require extra logic.
37                    // But in this simple case, it's just simple integer arithmetic.
38                    let every = match self.time_unit() {
39                        TimeUnit::Milliseconds => every_parsed.duration_ms(),
40                        TimeUnit::Microseconds => every_parsed.duration_us(),
41                        TimeUnit::Nanoseconds => every_parsed.duration_ns(),
42                    };
43                    return Ok(self
44                        .physical()
45                        .apply_values(|t| fast_round(t, every))
46                        .into_datetime(self.time_unit(), time_zone.clone()));
47                } else {
48                    let w = Window::new(every_parsed, every_parsed, offset);
49                    let out = match self.time_unit() {
50                        TimeUnit::Milliseconds => self
51                            .physical()
52                            .try_apply_nonnull_values_generic(|t| w.round_ms(t, tz)),
53                        TimeUnit::Microseconds => self
54                            .physical()
55                            .try_apply_nonnull_values_generic(|t| w.round_us(t, tz)),
56                        TimeUnit::Nanoseconds => self
57                            .physical()
58                            .try_apply_nonnull_values_generic(|t| w.round_ns(t, tz)),
59                    };
60                    return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone()));
61                }
62            } else {
63                return Ok(Int64Chunked::full_null(self.name().clone(), self.len())
64                    .into_datetime(self.time_unit(), self.time_zone().clone()));
65            }
66        }
67
68        polars_ensure!(
69            self.len() == every.len() || self.len() == 1,
70            length_mismatch = "dt.round",
71            self.len(),
72            every.len()
73        );
74
75        // A sqrt(n) cache is not too small, not too large.
76        let mut duration_cache = LruCache::with_capacity((every.len() as f64).sqrt() as usize);
77
78        let func = match self.time_unit() {
79            TimeUnit::Nanoseconds => Window::round_ns,
80            TimeUnit::Microseconds => Window::round_us,
81            TimeUnit::Milliseconds => Window::round_ms,
82        };
83
84        let out = broadcast_try_binary_elementwise(
85            self.physical(),
86            every,
87            |opt_timestamp, opt_every| match (opt_timestamp, opt_every) {
88                (Some(timestamp), Some(every)) => {
89                    let every = *duration_cache.get_or_insert_with(every, Duration::parse);
90
91                    if every.negative {
92                        polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
93                    }
94
95                    let w = Window::new(every, every, offset);
96                    func(&w, timestamp, tz).map(Some)
97                },
98                _ => Ok(None),
99            },
100        );
101        Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone()))
102    }
103}
104
105impl PolarsRound for DateChunked {
106    fn round(&self, every: &StringChunked, _tz: Option<&Tz>) -> PolarsResult<Self> {
107        let offset = Duration::new(0);
108        let out = match every.len() {
109            1 => {
110                if let Some(every) = every.get(0) {
111                    let every = Duration::try_parse(every)?;
112                    if every.negative {
113                        polars_bail!(ComputeError: "cannot round a Date to a negative duration")
114                    }
115                    let w = Window::new(every, every, offset);
116                    self.physical().try_apply_nonnull_values_generic(|t| {
117                        Ok(
118                            (w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)?
119                                / MILLISECONDS_IN_DAY) as i32,
120                        )
121                    })
122                } else {
123                    Ok(Int32Chunked::full_null(self.name().clone(), self.len()))
124                }
125            },
126            _ => {
127                polars_ensure!(
128                    self.len() == every.len() || self.len() == 1,
129                    length_mismatch = "dt.round",
130                    self.len(),
131                    every.len()
132                );
133                broadcast_try_binary_elementwise(self.physical(), every, |opt_t, opt_every| {
134                    // A sqrt(n) cache is not too small, not too large.
135                    let mut duration_cache =
136                        LruCache::with_capacity((every.len() as f64).sqrt() as usize);
137                    match (opt_t, opt_every) {
138                        (Some(t), Some(every)) => {
139                            let every = *duration_cache.get_or_insert_with(every, Duration::parse);
140
141                            if every.negative {
142                                polars_bail!(ComputeError: "cannot round a Date to a negative duration")
143                            }
144
145                            let w = Window::new(every, every, offset);
146                            Ok(Some(
147                                (w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)?
148                                    / MILLISECONDS_IN_DAY) as i32,
149                            ))
150                        },
151                        _ => Ok(None),
152                    }
153                })
154            },
155        };
156        Ok(out?.into_date())
157    }
158}