// polars_core/frame/group_by/aggregations/boolean.rs

1use arrow::bitmap::bitmask::BitMask;
2
3use super::*;
4use crate::chunked_array::cast::CastOptions;
5
6pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
7where
8    F: Fn((IdxSize, &IdxVec)) -> Option<bool> + Send + Sync,
9{
10    let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
11    ca.into_series()
12}
13
14pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
15where
16    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
17{
18    let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
19    ca.into_series()
20}
21
#[cfg(feature = "bitwise")]
impl BooleanChunked {
    /// Bitwise AND of each group, skipping nulls (equivalent to `agg_all`
    /// with `ignore_nulls = true`).
    ///
    /// # Safety
    ///
    /// Every index / slice in `groups` must be in bounds for `self`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_all(groups, true)
    }

    /// Bitwise OR of each group, skipping nulls (equivalent to `agg_any`
    /// with `ignore_nulls = true`).
    ///
    /// # Safety
    ///
    /// Every index / slice in `groups` must be in bounds for `self`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_any(groups, true)
    }

    /// Bitwise XOR of each group: `true` iff the group contains an odd number
    /// of valid `true` values. Nulls are skipped.
    ///
    /// # Safety
    ///
    /// Every index / slice in `groups` must be in bounds for `self`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> BooleanChunked {
        self.bool_agg(
            groups,
            // Nulls are always ignored, so the Kleene strategies are never used.
            true,
            |values, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    parity ^ unsafe { values.get_bit_unchecked(i as usize) }
                })
            },
            // Only positions that are both valid and set flip the parity; an odd
            // count yields `true`, matching the slice-based strategy below.
            |values, validity, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    parity
                        ^ unsafe {
                            validity.get_bit_unchecked(i as usize)
                                & values.get_bit_unchecked(i as usize)
                        }
                })
            },
            |_, _, _| unreachable!(),
            |values, start, length| {
                unsafe { values.sliced_unchecked(start as usize, length as usize) }.set_bits() % 2
                    == 1
            },
            |values, validity, start, length| {
                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
                let validity =
                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
                values.num_intersections_with(validity) % 2 == 1
            },
            |_, _, _, _| unreachable!(),
        )
    }
}

75impl BooleanChunked {
76    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
77        // faster paths
78        match self.is_sorted_flag() {
79            IsSorted::Ascending => return self.clone().into_series().agg_first_non_null(groups),
80            IsSorted::Descending => return self.clone().into_series().agg_last_non_null(groups),
81            _ => {},
82        }
83        let ca_self = self.rechunk();
84        let arr = ca_self.downcast_iter().next().unwrap();
85        let no_nulls = arr.null_count() == 0;
86        match groups {
87            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
88                debug_assert!(idx.len() <= self.len());
89                if idx.is_empty() {
90                    None
91                } else if idx.len() == 1 {
92                    arr.get(first as usize)
93                } else if no_nulls {
94                    take_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
95                } else {
96                    take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
97                }
98            }),
99            GroupsType::Slice {
100                groups: groups_slice,
101                ..
102            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
103                debug_assert!(len <= self.len() as IdxSize);
104                match len {
105                    0 => None,
106                    1 => self.get(first as usize),
107                    _ => {
108                        let arr_group = _slice_from_offsets(self, first, len);
109                        arr_group.min()
110                    },
111                }
112            }),
113        }
114    }
115    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
116        // faster paths
117        match (self.is_sorted_flag(), self.null_count()) {
118            (IsSorted::Ascending, 0) => {
119                return self.clone().into_series().agg_last(groups);
120            },
121            (IsSorted::Descending, 0) => {
122                return self.clone().into_series().agg_first(groups);
123            },
124            _ => {},
125        }
126
127        let ca_self = self.rechunk();
128        let arr = ca_self.downcast_iter().next().unwrap();
129        let no_nulls = arr.null_count() == 0;
130        match groups {
131            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
132                debug_assert!(idx.len() <= self.len());
133                if idx.is_empty() {
134                    None
135                } else if idx.len() == 1 {
136                    self.get(first as usize)
137                } else if no_nulls {
138                    take_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
139                } else {
140                    take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
141                }
142            }),
143            GroupsType::Slice {
144                groups: groups_slice,
145                ..
146            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
147                debug_assert!(len <= self.len() as IdxSize);
148                match len {
149                    0 => None,
150                    1 => self.get(first as usize),
151                    _ => {
152                        let arr_group = _slice_from_offsets(self, first, len);
153                        arr_group.max()
154                    },
155                }
156            }),
157        }
158    }
159
160    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
161        self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
162            .unwrap()
163            .agg_sum(groups)
164    }
165
166    /// # Safety
167    ///
168    /// Groups should be in correct.
169    #[expect(clippy::too_many_arguments)]
170    unsafe fn bool_agg(
171        &self,
172        groups: &GroupsType,
173        ignore_nulls: bool,
174
175        idx_no_valid: impl Fn(BitMask, &[IdxSize]) -> bool + Send + Sync,
176        idx_validity: impl Fn(BitMask, BitMask, &[IdxSize]) -> bool + Send + Sync,
177        idx_kleene: impl Fn(BitMask, BitMask, &[IdxSize]) -> Option<bool> + Send + Sync,
178
179        slice_no_valid: impl Fn(BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
180        slice_validity: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
181        slice_kleene: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> Option<bool> + Send + Sync,
182    ) -> BooleanChunked {
183        let name = self.name().clone();
184        let values = self.rechunk();
185        let values = values.downcast_as_array();
186
187        let ca: BooleanChunked = POOL.install(|| {
188            let validity = values
189                .validity()
190                .filter(|v| v.unset_bits() > 0)
191                .map(BitMask::from_bitmap);
192            let values = BitMask::from_bitmap(values.values());
193
194            if !ignore_nulls && let Some(validity) = validity {
195                match groups {
196                    GroupsType::Idx(idx) => idx
197                        .into_par_iter()
198                        .map(|(_, idx)| idx_kleene(values, validity, idx))
199                        .collect(),
200                    GroupsType::Slice {
201                        groups,
202                        overlapping: _,
203                        monotonic: _,
204                    } => groups
205                        .into_par_iter()
206                        .map(|[start, length]| slice_kleene(values, validity, *start, *length))
207                        .collect(),
208                }
209            } else {
210                match groups {
211                    GroupsType::Idx(idx) => match validity {
212                        None => idx
213                            .into_par_iter()
214                            .map(|(_, idx)| idx_no_valid(values, idx))
215                            .collect(),
216                        Some(validity) => idx
217                            .into_par_iter()
218                            .map(|(_, idx)| idx_validity(values, validity, idx))
219                            .collect(),
220                    },
221                    GroupsType::Slice {
222                        groups,
223                        overlapping: _,
224                        monotonic: _,
225                    } => match validity {
226                        None => groups
227                            .into_par_iter()
228                            .map(|[start, length]| slice_no_valid(values, *start, *length))
229                            .collect(),
230                        Some(validity) => groups
231                            .into_par_iter()
232                            .map(|[start, length]| {
233                                slice_validity(values, validity, *start, *length)
234                            })
235                            .collect(),
236                    },
237                }
238            }
239        });
240        ca.with_name(name)
241    }
242
243    /// # Safety
244    ///
245    /// Groups should be in correct.
246    pub unsafe fn agg_any(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
247        self.bool_agg(
248            groups,
249            ignore_nulls,
250            |values, idxs| {
251                idxs.iter()
252                    .any(|i| unsafe { values.get_bit_unchecked(*i as usize) })
253            },
254            |values, validity, idxs| {
255                idxs.iter().any(|i| unsafe {
256                    validity.get_bit_unchecked(*i as usize) & values.get_bit_unchecked(*i as usize)
257                })
258            },
259            |values, validity, idxs| {
260                let mut saw_null = false;
261                for i in idxs.iter() {
262                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
263                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
264
265                    if is_valid & is_true {
266                        return Some(true);
267                    }
268                    saw_null |= !is_valid;
269                }
270                (!saw_null).then_some(false)
271            },
272            |values, start, length| {
273                unsafe { values.sliced_unchecked(start as usize, length as usize) }.leading_zeros()
274                    < length as usize
275            },
276            |values, validity, start, length| {
277                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
278                let validity =
279                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
280                values.intersects_with(validity)
281            },
282            |values, validity, start, length| {
283                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
284                let validity =
285                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
286
287                if values.intersects_with(validity) {
288                    Some(true)
289                } else if validity.unset_bits() == 0 {
290                    Some(false)
291                } else {
292                    None
293                }
294            },
295        )
296    }
297
298    /// # Safety
299    ///
300    /// Groups should be in correct.
301    pub unsafe fn agg_all(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
302        self.bool_agg(
303            groups,
304            ignore_nulls,
305            |values, idxs| {
306                idxs.iter()
307                    .all(|i| unsafe { values.get_bit_unchecked(*i as usize) })
308            },
309            |values, validity, idxs| {
310                idxs.iter().all(|i| unsafe {
311                    !validity.get_bit_unchecked(*i as usize) | values.get_bit_unchecked(*i as usize)
312                })
313            },
314            |values, validity, idxs| {
315                let mut saw_null = false;
316                for i in idxs.iter() {
317                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
318                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
319
320                    if is_valid & !is_true {
321                        return Some(false);
322                    }
323                    saw_null |= !is_valid;
324                }
325                (!saw_null).then_some(true)
326            },
327            |values, start, length| {
328                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
329                values.unset_bits() == 0
330            },
331            |values, validity, start, length| {
332                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
333                let validity =
334                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
335                values.num_intersections_with(validity) == validity.set_bits()
336            },
337            |values, validity, start, length| {
338                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
339                let validity =
340                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
341
342                let num_non_nulls = validity.set_bits();
343
344                if values.num_intersections_with(validity) < num_non_nulls {
345                    Some(false)
346                } else if num_non_nulls < values.len() {
347                    None
348                } else {
349                    Some(true)
350                }
351            },
352        )
353    }
354}