polars_core/frame/group_by/aggregations/boolean.rs

1use arrow::bitmap::bitmask::BitMask;
2
3use super::*;
4use crate::chunked_array::cast::CastOptions;
5
6pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
7where
8    F: Fn((IdxSize, &IdxVec)) -> Option<bool> + Send + Sync,
9{
10    let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
11    ca.into_series()
12}
13
14pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
15where
16    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
17{
18    let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
19    ca.into_series()
20}
21
#[cfg(feature = "bitwise")]
impl BooleanChunked {
    /// Per-group bitwise AND; nulls are ignored (delegates to `agg_all` with
    /// `ignore_nulls = true`).
    ///
    /// # Safety
    ///
    /// Every index or slice in `groups` must be in bounds for `self`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_all(groups, true)
    }

    /// Per-group bitwise OR; nulls are ignored (delegates to `agg_any` with
    /// `ignore_nulls = true`).
    ///
    /// # Safety
    ///
    /// Every index or slice in `groups` must be in bounds for `self`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_any(groups, true)
    }

    /// Per-group bitwise XOR; nulls are ignored.
    ///
    /// A group yields `true` iff it contains an odd number of valid `true` values.
    /// An empty or all-null group yields `false`.
    ///
    /// # Safety
    ///
    /// Every index or slice in `groups` must be in bounds for `self`; bits are read
    /// with unchecked accessors.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> BooleanChunked {
        self.bool_agg(
            groups,
            true,
            // Index groups, no nulls in the column: parity of the set bits.
            |values, idxs| {
                idxs.iter().fold(false, |parity, i| {
                    parity ^ unsafe { values.get_bit_unchecked(*i as usize) }
                })
            },
            // Index groups, column has nulls: a null counts as `false`, so it never
            // flips the parity. (Previously this compared the count with `== 0`,
            // which inverted the result relative to every other variant.)
            |values, validity, idxs| {
                idxs.iter().fold(false, |parity, i| {
                    parity
                        ^ unsafe {
                            validity.get_bit_unchecked(*i as usize)
                                & values.get_bit_unchecked(*i as usize)
                        }
                })
            },
            // `ignore_nulls` is `true`, so `bool_agg` never selects the Kleene closures.
            |_, _, _| unreachable!(),
            // Slice groups, no nulls: parity of the set bits in the window.
            |values, start, length| {
                unsafe { values.sliced_unchecked(start as usize, length as usize) }.set_bits() % 2
                    == 1
            },
            // Slice groups, column has nulls: parity of bits that are both valid and true.
            |values, validity, start, length| {
                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
                let validity =
                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
                values.num_intersections_with(validity) % 2 == 1
            },
            // Unreachable for the same reason as the index Kleene closure above.
            |_, _, _, _| unreachable!(),
        )
    }
}
74
75impl BooleanChunked {
76    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
77        // faster paths
78        if groups.is_sorted_flag() {
79            match self.is_sorted_flag() {
80                IsSorted::Ascending => {
81                    return self.clone().into_series().agg_first_non_null(groups);
82                },
83                IsSorted::Descending => {
84                    return self.clone().into_series().agg_last_non_null(groups);
85                },
86                _ => {},
87            }
88        }
89        let ca_self = self.rechunk();
90        let arr = ca_self.downcast_iter().next().unwrap();
91        let no_nulls = arr.null_count() == 0;
92        match groups {
93            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
94                debug_assert!(idx.len() <= self.len());
95                if idx.is_empty() {
96                    None
97                } else if idx.len() == 1 {
98                    arr.get(first as usize)
99                } else if no_nulls {
100                    take_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
101                } else {
102                    take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
103                }
104            }),
105            GroupsType::Slice {
106                groups: groups_slice,
107                ..
108            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
109                debug_assert!(len <= self.len() as IdxSize);
110                match len {
111                    0 => None,
112                    1 => self.get(first as usize),
113                    _ => {
114                        let arr_group = _slice_from_offsets(self, first, len);
115                        arr_group.min()
116                    },
117                }
118            }),
119        }
120    }
121    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
122        // faster paths
123        if groups.is_sorted_flag() {
124            match self.is_sorted_flag() {
125                IsSorted::Ascending => return self.clone().into_series().agg_last_non_null(groups),
126                IsSorted::Descending => {
127                    return self.clone().into_series().agg_first_non_null(groups);
128                },
129                _ => {},
130            }
131        }
132
133        let ca_self = self.rechunk();
134        let arr = ca_self.downcast_iter().next().unwrap();
135        let no_nulls = arr.null_count() == 0;
136        match groups {
137            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
138                debug_assert!(idx.len() <= self.len());
139                if idx.is_empty() {
140                    None
141                } else if idx.len() == 1 {
142                    self.get(first as usize)
143                } else if no_nulls {
144                    take_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
145                } else {
146                    take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
147                }
148            }),
149            GroupsType::Slice {
150                groups: groups_slice,
151                ..
152            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
153                debug_assert!(len <= self.len() as IdxSize);
154                match len {
155                    0 => None,
156                    1 => self.get(first as usize),
157                    _ => {
158                        let arr_group = _slice_from_offsets(self, first, len);
159                        arr_group.max()
160                    },
161                }
162            }),
163        }
164    }
165
166    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
167        self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
168            .unwrap()
169            .agg_sum(groups)
170    }
171
172    /// # Safety
173    ///
174    /// Groups should be in correct.
175    #[expect(clippy::too_many_arguments)]
176    unsafe fn bool_agg(
177        &self,
178        groups: &GroupsType,
179        ignore_nulls: bool,
180
181        idx_no_valid: impl Fn(BitMask, &[IdxSize]) -> bool + Send + Sync,
182        idx_validity: impl Fn(BitMask, BitMask, &[IdxSize]) -> bool + Send + Sync,
183        idx_kleene: impl Fn(BitMask, BitMask, &[IdxSize]) -> Option<bool> + Send + Sync,
184
185        slice_no_valid: impl Fn(BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
186        slice_validity: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
187        slice_kleene: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> Option<bool> + Send + Sync,
188    ) -> BooleanChunked {
189        let name = self.name().clone();
190        let values = self.rechunk();
191        let values = values.downcast_as_array();
192
193        let ca: BooleanChunked = POOL.install(|| {
194            let validity = values
195                .validity()
196                .filter(|v| v.unset_bits() > 0)
197                .map(BitMask::from_bitmap);
198            let values = BitMask::from_bitmap(values.values());
199
200            if !ignore_nulls && let Some(validity) = validity {
201                match groups {
202                    GroupsType::Idx(idx) => idx
203                        .into_par_iter()
204                        .map(|(_, idx)| idx_kleene(values, validity, idx))
205                        .collect(),
206                    GroupsType::Slice {
207                        groups,
208                        overlapping: _,
209                        monotonic: _,
210                    } => groups
211                        .into_par_iter()
212                        .map(|[start, length]| slice_kleene(values, validity, *start, *length))
213                        .collect(),
214                }
215            } else {
216                match groups {
217                    GroupsType::Idx(idx) => match validity {
218                        None => idx
219                            .into_par_iter()
220                            .map(|(_, idx)| idx_no_valid(values, idx))
221                            .collect(),
222                        Some(validity) => idx
223                            .into_par_iter()
224                            .map(|(_, idx)| idx_validity(values, validity, idx))
225                            .collect(),
226                    },
227                    GroupsType::Slice {
228                        groups,
229                        overlapping: _,
230                        monotonic: _,
231                    } => match validity {
232                        None => groups
233                            .into_par_iter()
234                            .map(|[start, length]| slice_no_valid(values, *start, *length))
235                            .collect(),
236                        Some(validity) => groups
237                            .into_par_iter()
238                            .map(|[start, length]| {
239                                slice_validity(values, validity, *start, *length)
240                            })
241                            .collect(),
242                    },
243                }
244            }
245        });
246        ca.with_name(name)
247    }
248
249    /// # Safety
250    ///
251    /// Groups should be in correct.
252    pub unsafe fn agg_any(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
253        self.bool_agg(
254            groups,
255            ignore_nulls,
256            |values, idxs| {
257                idxs.iter()
258                    .any(|i| unsafe { values.get_bit_unchecked(*i as usize) })
259            },
260            |values, validity, idxs| {
261                idxs.iter().any(|i| unsafe {
262                    validity.get_bit_unchecked(*i as usize) & values.get_bit_unchecked(*i as usize)
263                })
264            },
265            |values, validity, idxs| {
266                let mut saw_null = false;
267                for i in idxs.iter() {
268                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
269                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
270
271                    if is_valid & is_true {
272                        return Some(true);
273                    }
274                    saw_null |= !is_valid;
275                }
276                (!saw_null).then_some(false)
277            },
278            |values, start, length| {
279                unsafe { values.sliced_unchecked(start as usize, length as usize) }.leading_zeros()
280                    < length as usize
281            },
282            |values, validity, start, length| {
283                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
284                let validity =
285                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
286                values.intersects_with(validity)
287            },
288            |values, validity, start, length| {
289                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
290                let validity =
291                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
292
293                if values.intersects_with(validity) {
294                    Some(true)
295                } else if validity.unset_bits() == 0 {
296                    Some(false)
297                } else {
298                    None
299                }
300            },
301        )
302    }
303
304    /// # Safety
305    ///
306    /// Groups should be in correct.
307    pub unsafe fn agg_all(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
308        self.bool_agg(
309            groups,
310            ignore_nulls,
311            |values, idxs| {
312                idxs.iter()
313                    .all(|i| unsafe { values.get_bit_unchecked(*i as usize) })
314            },
315            |values, validity, idxs| {
316                idxs.iter().all(|i| unsafe {
317                    !validity.get_bit_unchecked(*i as usize) | values.get_bit_unchecked(*i as usize)
318                })
319            },
320            |values, validity, idxs| {
321                let mut saw_null = false;
322                for i in idxs.iter() {
323                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
324                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
325
326                    if is_valid & !is_true {
327                        return Some(false);
328                    }
329                    saw_null |= !is_valid;
330                }
331                (!saw_null).then_some(true)
332            },
333            |values, start, length| {
334                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
335                values.unset_bits() == 0
336            },
337            |values, validity, start, length| {
338                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
339                let validity =
340                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
341                values.num_intersections_with(validity) == validity.set_bits()
342            },
343            |values, validity, start, length| {
344                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
345                let validity =
346                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
347
348                let num_non_nulls = validity.set_bits();
349
350                if values.num_intersections_with(validity) < num_non_nulls {
351                    Some(false)
352                } else if num_non_nulls < values.len() {
353                    None
354                } else {
355                    Some(true)
356                }
357            },
358        )
359    }
360}