polars_core/frame/group_by/aggregations/
boolean.rs

1use arrow::bitmap::bitmask::BitMask;
2
3use super::*;
4use crate::chunked_array::cast::CastOptions;
5
6pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
7where
8    F: Fn((IdxSize, &IdxVec)) -> Option<bool> + Send + Sync,
9{
10    let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
11    ca.into_series()
12}
13
14pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
15where
16    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
17{
18    let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
19    ca.into_series()
20}
21
#[cfg(feature = "bitwise")]
impl BooleanChunked {
    /// Per-group bitwise AND: delegates to [`Self::agg_all`] with nulls ignored.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::agg_all`]: every position referenced by `groups`
    /// must be in bounds for `self`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_all(groups, true)
    }

    /// Per-group bitwise OR: delegates to [`Self::agg_any`] with nulls ignored.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::agg_any`]: every position referenced by `groups`
    /// must be in bounds for `self`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_any(groups, true)
    }

    /// Per-group bitwise XOR: a group yields `true` when it contains an odd number
    /// of valid `true` values. Null entries are ignored.
    ///
    /// # Safety
    ///
    /// Every index / slice bound in `groups` must be in bounds for `self`; the bits
    /// are read with unchecked accesses.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> BooleanChunked {
        self.bool_agg(
            groups,
            // Nulls are always ignored, so the Kleene callbacks below are never invoked.
            true,
            // Index groups, no nulls: XOR of all selected bits.
            |values, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    parity ^ unsafe { values.get_bit_unchecked(i as usize) }
                })
            },
            // Index groups with nulls: only valid entries contribute to the parity.
            // The result is `true` for an ODD count, matching the slice paths below.
            |values, validity, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    let i = i as usize;
                    parity
                        ^ unsafe { validity.get_bit_unchecked(i) & values.get_bit_unchecked(i) }
                })
            },
            |_, _, _| unreachable!(),
            // Slice groups, no nulls: parity of the number of set bits in the window.
            |values, start, length| {
                let window = unsafe { values.sliced_unchecked(start as usize, length as usize) };
                window.set_bits() % 2 == 1
            },
            // Slice groups with nulls: parity of the bits that are both set and valid.
            |values, validity, start, length| {
                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
                let validity =
                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
                values.num_intersections_with(validity) % 2 == 1
            },
            |_, _, _, _| unreachable!(),
        )
    }
}
74
75impl BooleanChunked {
76    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
77        // faster paths
78        match self.is_sorted_flag() {
79            IsSorted::Ascending => return self.clone().into_series().agg_first_non_null(groups),
80            IsSorted::Descending => return self.clone().into_series().agg_last_non_null(groups),
81            _ => {},
82        }
83        let ca_self = self.rechunk();
84        let arr = ca_self.downcast_iter().next().unwrap();
85        let no_nulls = arr.null_count() == 0;
86        match groups {
87            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
88                debug_assert!(idx.len() <= self.len());
89                if idx.is_empty() {
90                    None
91                } else if idx.len() == 1 {
92                    arr.get(first as usize)
93                } else if no_nulls {
94                    take_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
95                } else {
96                    take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
97                }
98            }),
99            GroupsType::Slice {
100                groups: groups_slice,
101                ..
102            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
103                debug_assert!(len <= self.len() as IdxSize);
104                match len {
105                    0 => None,
106                    1 => self.get(first as usize),
107                    _ => {
108                        let arr_group = _slice_from_offsets(self, first, len);
109                        arr_group.min()
110                    },
111                }
112            }),
113        }
114    }
115    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
116        // faster paths
117        match (self.is_sorted_flag(), self.null_count()) {
118            (IsSorted::Ascending, 0) => {
119                return self.clone().into_series().agg_last(groups);
120            },
121            (IsSorted::Descending, 0) => {
122                return self.clone().into_series().agg_first(groups);
123            },
124            _ => {},
125        }
126
127        let ca_self = self.rechunk();
128        let arr = ca_self.downcast_iter().next().unwrap();
129        let no_nulls = arr.null_count() == 0;
130        match groups {
131            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
132                debug_assert!(idx.len() <= self.len());
133                if idx.is_empty() {
134                    None
135                } else if idx.len() == 1 {
136                    self.get(first as usize)
137                } else if no_nulls {
138                    take_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
139                } else {
140                    take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
141                }
142            }),
143            GroupsType::Slice {
144                groups: groups_slice,
145                ..
146            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
147                debug_assert!(len <= self.len() as IdxSize);
148                match len {
149                    0 => None,
150                    1 => self.get(first as usize),
151                    _ => {
152                        let arr_group = _slice_from_offsets(self, first, len);
153                        arr_group.max()
154                    },
155                }
156            }),
157        }
158    }
159
160    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
161        self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
162            .unwrap()
163            .agg_sum(groups)
164    }
165
166    /// # Safety
167    ///
168    /// Groups should be in correct.
169    #[expect(clippy::too_many_arguments)]
170    unsafe fn bool_agg(
171        &self,
172        groups: &GroupsType,
173        ignore_nulls: bool,
174
175        idx_no_valid: impl Fn(BitMask, &[IdxSize]) -> bool + Send + Sync,
176        idx_validity: impl Fn(BitMask, BitMask, &[IdxSize]) -> bool + Send + Sync,
177        idx_kleene: impl Fn(BitMask, BitMask, &[IdxSize]) -> Option<bool> + Send + Sync,
178
179        slice_no_valid: impl Fn(BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
180        slice_validity: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
181        slice_kleene: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> Option<bool> + Send + Sync,
182    ) -> BooleanChunked {
183        let name = self.name().clone();
184        let values = self.rechunk();
185        let values = values.downcast_as_array();
186
187        let ca: BooleanChunked = POOL.install(|| {
188            let validity = values
189                .validity()
190                .filter(|v| v.unset_bits() > 0)
191                .map(BitMask::from_bitmap);
192            let values = BitMask::from_bitmap(values.values());
193
194            if !ignore_nulls && let Some(validity) = validity {
195                match groups {
196                    GroupsType::Idx(idx) => idx
197                        .into_par_iter()
198                        .map(|(_, idx)| idx_kleene(values, validity, idx))
199                        .collect(),
200                    GroupsType::Slice {
201                        groups,
202                        overlapping: _,
203                    } => groups
204                        .into_par_iter()
205                        .map(|[start, length]| slice_kleene(values, validity, *start, *length))
206                        .collect(),
207                }
208            } else {
209                match groups {
210                    GroupsType::Idx(idx) => match validity {
211                        None => idx
212                            .into_par_iter()
213                            .map(|(_, idx)| idx_no_valid(values, idx))
214                            .collect(),
215                        Some(validity) => idx
216                            .into_par_iter()
217                            .map(|(_, idx)| idx_validity(values, validity, idx))
218                            .collect(),
219                    },
220                    GroupsType::Slice {
221                        groups,
222                        overlapping: _,
223                    } => match validity {
224                        None => groups
225                            .into_par_iter()
226                            .map(|[start, length]| slice_no_valid(values, *start, *length))
227                            .collect(),
228                        Some(validity) => groups
229                            .into_par_iter()
230                            .map(|[start, length]| {
231                                slice_validity(values, validity, *start, *length)
232                            })
233                            .collect(),
234                    },
235                }
236            }
237        });
238        ca.with_name(name)
239    }
240
241    /// # Safety
242    ///
243    /// Groups should be in correct.
244    pub unsafe fn agg_any(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
245        self.bool_agg(
246            groups,
247            ignore_nulls,
248            |values, idxs| {
249                idxs.iter()
250                    .any(|i| unsafe { values.get_bit_unchecked(*i as usize) })
251            },
252            |values, validity, idxs| {
253                idxs.iter().any(|i| unsafe {
254                    validity.get_bit_unchecked(*i as usize) & values.get_bit_unchecked(*i as usize)
255                })
256            },
257            |values, validity, idxs| {
258                let mut saw_null = false;
259                for i in idxs.iter() {
260                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
261                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
262
263                    if is_valid & is_true {
264                        return Some(true);
265                    }
266                    saw_null |= !is_valid;
267                }
268                (!saw_null).then_some(false)
269            },
270            |values, start, length| {
271                unsafe { values.sliced_unchecked(start as usize, length as usize) }.leading_zeros()
272                    < length as usize
273            },
274            |values, validity, start, length| {
275                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
276                let validity =
277                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
278                values.intersects_with(validity)
279            },
280            |values, validity, start, length| {
281                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
282                let validity =
283                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
284
285                if values.intersects_with(validity) {
286                    Some(true)
287                } else if validity.unset_bits() == 0 {
288                    Some(false)
289                } else {
290                    None
291                }
292            },
293        )
294    }
295
296    /// # Safety
297    ///
298    /// Groups should be in correct.
299    pub unsafe fn agg_all(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
300        self.bool_agg(
301            groups,
302            ignore_nulls,
303            |values, idxs| {
304                idxs.iter()
305                    .all(|i| unsafe { values.get_bit_unchecked(*i as usize) })
306            },
307            |values, validity, idxs| {
308                idxs.iter().all(|i| unsafe {
309                    !validity.get_bit_unchecked(*i as usize) | values.get_bit_unchecked(*i as usize)
310                })
311            },
312            |values, validity, idxs| {
313                let mut saw_null = false;
314                for i in idxs.iter() {
315                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
316                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
317
318                    if is_valid & !is_true {
319                        return Some(false);
320                    }
321                    saw_null |= !is_valid;
322                }
323                (!saw_null).then_some(true)
324            },
325            |values, start, length| {
326                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
327                values.unset_bits() == 0
328            },
329            |values, validity, start, length| {
330                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
331                let validity =
332                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
333                values.num_intersections_with(validity) == validity.set_bits()
334            },
335            |values, validity, start, length| {
336                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
337                let validity =
338                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
339
340                let num_non_nulls = validity.set_bits();
341
342                if values.num_intersections_with(validity) < num_non_nulls {
343                    Some(false)
344                } else if num_non_nulls < values.len() {
345                    None
346                } else {
347                    Some(true)
348                }
349            },
350        )
351    }
352}