Skip to main content

polars_core/frame/group_by/aggregations/boolean.rs

1use arrow::bitmap::bitmask::BitMask;
2
3use super::*;
4use crate::chunked_array::cast::CastOptions;
5use crate::chunked_array::{arg_max_bool, arg_min_bool};
6
7pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
8where
9    F: Fn((IdxSize, &IdxVec)) -> Option<bool> + Send + Sync,
10{
11    let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
12    ca.into_series()
13}
14
15pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
16where
17    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
18{
19    let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
20    ca.into_series()
21}
22
#[cfg(feature = "bitwise")]
impl BooleanChunked {
    /// Group-wise bitwise AND: delegates to `agg_all` with nulls ignored.
    ///
    /// # Safety
    ///
    /// Same contract as `agg_all`: every row index referenced by `groups` must be
    /// in bounds for `self`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> BooleanChunked {
        let ignore_nulls = true;
        self.agg_all(groups, ignore_nulls)
    }

    /// Group-wise bitwise OR: delegates to `agg_any` with nulls ignored.
    ///
    /// # Safety
    ///
    /// Same contract as `agg_any`: every row index referenced by `groups` must be
    /// in bounds for `self`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> BooleanChunked {
        let ignore_nulls = true;
        self.agg_any(groups, ignore_nulls)
    }

    /// Group-wise bitwise XOR: `true` when a group holds an odd number of valid
    /// `true` values. Nulls are ignored, so the Kleene callbacks are never used.
    ///
    /// # Safety
    ///
    /// Every row index referenced by `groups` must be in bounds for `self`; bits
    /// are read with unchecked accessors.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> BooleanChunked {
        self.bool_agg(
            groups,
            true,
            // No nulls: parity of the set bits at the group's indices.
            |bits, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    parity ^ unsafe { bits.get_bit_unchecked(i as usize) }
                })
            },
            // Nulls present but ignored: only count bits that are valid and set.
            |bits, valid, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    let i = i as usize;
                    parity ^ unsafe { valid.get_bit_unchecked(i) & bits.get_bit_unchecked(i) }
                })
            },
            // Unreachable: `ignore_nulls` is `true` above.
            |_, _, _| unreachable!(),
            |bits, start, length| {
                let window = unsafe { bits.sliced_unchecked(start as usize, length as usize) };
                window.set_bits() % 2 != 0
            },
            |bits, valid, start, length| {
                let (offset, len) = (start as usize, length as usize);
                let window = unsafe { bits.sliced_unchecked(offset, len) };
                let valid_window = unsafe { valid.sliced_unchecked(offset, len) };
                window.num_intersections_with(valid_window) % 2 != 0
            },
            // Unreachable: `ignore_nulls` is `true` above.
            |_, _, _, _| unreachable!(),
        )
    }
}
75
76impl BooleanChunked {
77    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
78        // faster paths
79        if groups.is_sorted_flag() {
80            match self.is_sorted_flag() {
81                IsSorted::Ascending => {
82                    return self.clone().into_series().agg_first_non_null(groups);
83                },
84                IsSorted::Descending => {
85                    return self.clone().into_series().agg_last_non_null(groups);
86                },
87                _ => {},
88            }
89        }
90        let ca_self = self.rechunk();
91        let arr = ca_self.downcast_iter().next().unwrap();
92        let no_nulls = arr.null_count() == 0;
93        match groups {
94            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
95                debug_assert!(idx.len() <= self.len());
96                if idx.is_empty() {
97                    None
98                } else if idx.len() == 1 {
99                    arr.get(first as usize)
100                } else if no_nulls {
101                    take_arg_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
102                        .map(|p| arr.value_unchecked(idx[p] as usize))
103                } else {
104                    take_arg_min_bool_iter_unchecked_nulls(arr, idx2usize(idx))
105                        .map(|p| arr.value_unchecked(idx[p] as usize))
106                }
107            }),
108            GroupsType::Slice {
109                groups: groups_slice,
110                ..
111            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
112                debug_assert!(len <= self.len() as IdxSize);
113                match len {
114                    0 => None,
115                    1 => self.get(first as usize),
116                    _ => {
117                        let arr_group = _slice_from_offsets(self, first, len);
118                        arr_group.min()
119                    },
120                }
121            }),
122        }
123    }
124    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
125        // faster paths
126        if groups.is_sorted_flag() {
127            match self.is_sorted_flag() {
128                IsSorted::Ascending => return self.clone().into_series().agg_last_non_null(groups),
129                IsSorted::Descending => {
130                    return self.clone().into_series().agg_first_non_null(groups);
131                },
132                _ => {},
133            }
134        }
135
136        let ca_self = self.rechunk();
137        let arr = ca_self.downcast_iter().next().unwrap();
138        let no_nulls = arr.null_count() == 0;
139        match groups {
140            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
141                debug_assert!(idx.len() <= self.len());
142                if idx.is_empty() {
143                    None
144                } else if idx.len() == 1 {
145                    self.get(first as usize)
146                } else if no_nulls {
147                    take_arg_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
148                        .map(|p| arr.value_unchecked(idx[p] as usize))
149                } else {
150                    take_arg_max_bool_iter_unchecked_nulls(arr, idx2usize(idx))
151                        .map(|p| arr.value_unchecked(idx[p] as usize))
152                }
153            }),
154            GroupsType::Slice {
155                groups: groups_slice,
156                ..
157            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
158                debug_assert!(len <= self.len() as IdxSize);
159                match len {
160                    0 => None,
161                    1 => self.get(first as usize),
162                    _ => {
163                        let arr_group = _slice_from_offsets(self, first, len);
164                        arr_group.max()
165                    },
166                }
167            }),
168        }
169    }
170
171    pub(crate) unsafe fn agg_arg_min(&self, groups: &GroupsType) -> Series {
172        // faster paths
173        if groups.is_sorted_flag() {
174            match self.is_sorted_flag() {
175                IsSorted::Ascending => {
176                    return self.clone().into_series().agg_arg_first_non_null(groups);
177                },
178                IsSorted::Descending => {
179                    return self.clone().into_series().agg_arg_last_non_null(groups);
180                },
181                _ => {},
182            }
183        }
184
185        let ca_self = self.rechunk();
186        let arr = ca_self.downcast_iter().next().unwrap();
187        let no_nulls = arr.null_count() == 0;
188        match groups {
189            GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
190                debug_assert!(idx.len() <= ca_self.len());
191                if idx.is_empty() {
192                    None
193                } else if idx.len() == 1 {
194                    arr.get(idx[0] as usize).map(|_| 0)
195                } else if no_nulls {
196                    take_arg_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
197                        .map(|p| p as IdxSize)
198                } else {
199                    take_arg_min_bool_iter_unchecked_nulls(arr, idx2usize(idx))
200                        .map(|p| p as IdxSize)
201                }
202            }),
203            GroupsType::Slice {
204                groups: groups_slice,
205                ..
206            } => _agg_helper_slice::<IdxType, _>(groups_slice, |[first, len]| {
207                debug_assert!(len <= self.len() as IdxSize);
208                match len {
209                    0 => None,
210                    1 => self.get(first as usize).map(|_| 0),
211                    _ => {
212                        let group_ca = _slice_from_offsets(self, first, len);
213                        arg_min_bool(&group_ca).map(|p| p as IdxSize)
214                    },
215                }
216            }),
217        }
218    }
219
220    pub(crate) unsafe fn agg_arg_max(&self, groups: &GroupsType) -> Series {
221        // faster paths
222        if groups.is_sorted_flag() {
223            match self.is_sorted_flag() {
224                IsSorted::Ascending => {
225                    return self.clone().into_series().agg_arg_last_non_null(groups);
226                },
227                IsSorted::Descending => {
228                    return self.clone().into_series().agg_arg_first_non_null(groups);
229                },
230                _ => {},
231            }
232        }
233
234        let ca_self = self.rechunk();
235        let arr = ca_self.downcast_iter().next().unwrap();
236        let no_nulls = arr.null_count() == 0;
237        match groups {
238            GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
239                debug_assert!(idx.len() <= ca_self.len());
240                if idx.is_empty() {
241                    None
242                } else if idx.len() == 1 {
243                    arr.get(idx[0] as usize).map(|_| 0)
244                } else if no_nulls {
245                    take_arg_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
246                        .map(|p| p as IdxSize)
247                } else {
248                    take_arg_max_bool_iter_unchecked_nulls(arr, idx2usize(idx))
249                        .map(|p| p as IdxSize)
250                }
251            }),
252            GroupsType::Slice {
253                groups: groups_slice,
254                ..
255            } => _agg_helper_slice::<IdxType, _>(groups_slice, |[first, len]| {
256                debug_assert!(len <= self.len() as IdxSize);
257                match len {
258                    0 => None,
259                    1 => self.get(first as usize).map(|_| 0),
260                    _ => {
261                        let group_ca = _slice_from_offsets(self, first, len);
262                        arg_max_bool(&group_ca).map(|p| p as IdxSize)
263                    },
264                }
265            }),
266        }
267    }
268
269    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
270        self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
271            .unwrap()
272            .agg_sum(groups)
273    }
274
275    /// # Safety
276    ///
277    /// Groups should be in correct.
278    #[expect(clippy::too_many_arguments)]
279    unsafe fn bool_agg(
280        &self,
281        groups: &GroupsType,
282        ignore_nulls: bool,
283
284        idx_no_valid: impl Fn(BitMask, &[IdxSize]) -> bool + Send + Sync,
285        idx_validity: impl Fn(BitMask, BitMask, &[IdxSize]) -> bool + Send + Sync,
286        idx_kleene: impl Fn(BitMask, BitMask, &[IdxSize]) -> Option<bool> + Send + Sync,
287
288        slice_no_valid: impl Fn(BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
289        slice_validity: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
290        slice_kleene: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> Option<bool> + Send + Sync,
291    ) -> BooleanChunked {
292        let name = self.name().clone();
293        let values = self.rechunk();
294        let values = values.downcast_as_array();
295
296        let ca: BooleanChunked = POOL.install(|| {
297            let validity = values
298                .validity()
299                .filter(|v| v.unset_bits() > 0)
300                .map(BitMask::from_bitmap);
301            let values = BitMask::from_bitmap(values.values());
302
303            if !ignore_nulls && let Some(validity) = validity {
304                match groups {
305                    GroupsType::Idx(idx) => idx
306                        .into_par_iter()
307                        .map(|(_, idx)| idx_kleene(values, validity, idx))
308                        .collect(),
309                    GroupsType::Slice {
310                        groups,
311                        overlapping: _,
312                        monotonic: _,
313                    } => groups
314                        .into_par_iter()
315                        .map(|[start, length]| slice_kleene(values, validity, *start, *length))
316                        .collect(),
317                }
318            } else {
319                match groups {
320                    GroupsType::Idx(idx) => match validity {
321                        None => idx
322                            .into_par_iter()
323                            .map(|(_, idx)| idx_no_valid(values, idx))
324                            .collect(),
325                        Some(validity) => idx
326                            .into_par_iter()
327                            .map(|(_, idx)| idx_validity(values, validity, idx))
328                            .collect(),
329                    },
330                    GroupsType::Slice {
331                        groups,
332                        overlapping: _,
333                        monotonic: _,
334                    } => match validity {
335                        None => groups
336                            .into_par_iter()
337                            .map(|[start, length]| slice_no_valid(values, *start, *length))
338                            .collect(),
339                        Some(validity) => groups
340                            .into_par_iter()
341                            .map(|[start, length]| {
342                                slice_validity(values, validity, *start, *length)
343                            })
344                            .collect(),
345                    },
346                }
347            }
348        });
349        ca.with_name(name)
350    }
351
352    /// # Safety
353    ///
354    /// Groups should be in correct.
355    pub unsafe fn agg_any(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
356        self.bool_agg(
357            groups,
358            ignore_nulls,
359            |values, idxs| {
360                idxs.iter()
361                    .any(|i| unsafe { values.get_bit_unchecked(*i as usize) })
362            },
363            |values, validity, idxs| {
364                idxs.iter().any(|i| unsafe {
365                    validity.get_bit_unchecked(*i as usize) & values.get_bit_unchecked(*i as usize)
366                })
367            },
368            |values, validity, idxs| {
369                let mut saw_null = false;
370                for i in idxs.iter() {
371                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
372                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
373
374                    if is_valid & is_true {
375                        return Some(true);
376                    }
377                    saw_null |= !is_valid;
378                }
379                (!saw_null).then_some(false)
380            },
381            |values, start, length| {
382                unsafe { values.sliced_unchecked(start as usize, length as usize) }.leading_zeros()
383                    < length as usize
384            },
385            |values, validity, start, length| {
386                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
387                let validity =
388                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
389                values.intersects_with(validity)
390            },
391            |values, validity, start, length| {
392                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
393                let validity =
394                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
395
396                if values.intersects_with(validity) {
397                    Some(true)
398                } else if validity.unset_bits() == 0 {
399                    Some(false)
400                } else {
401                    None
402                }
403            },
404        )
405    }
406
407    /// # Safety
408    ///
409    /// Groups should be in correct.
410    pub unsafe fn agg_all(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
411        self.bool_agg(
412            groups,
413            ignore_nulls,
414            |values, idxs| {
415                idxs.iter()
416                    .all(|i| unsafe { values.get_bit_unchecked(*i as usize) })
417            },
418            |values, validity, idxs| {
419                idxs.iter().all(|i| unsafe {
420                    !validity.get_bit_unchecked(*i as usize) | values.get_bit_unchecked(*i as usize)
421                })
422            },
423            |values, validity, idxs| {
424                let mut saw_null = false;
425                for i in idxs.iter() {
426                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
427                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
428
429                    if is_valid & !is_true {
430                        return Some(false);
431                    }
432                    saw_null |= !is_valid;
433                }
434                (!saw_null).then_some(true)
435            },
436            |values, start, length| {
437                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
438                values.unset_bits() == 0
439            },
440            |values, validity, start, length| {
441                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
442                let validity =
443                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
444                values.num_intersections_with(validity) == validity.set_bits()
445            },
446            |values, validity, start, length| {
447                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
448                let validity =
449                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
450
451                let num_non_nulls = validity.set_bits();
452
453                if values.num_intersections_with(validity) < num_non_nulls {
454                    Some(false)
455                } else if num_non_nulls < values.len() {
456                    None
457                } else {
458                    Some(true)
459                }
460            },
461        )
462    }
463}