polars_core/frame/group_by/aggregations/boolean.rs

1use std::borrow::Cow;
2
3use super::*;
4use crate::chunked_array::cast::CastOptions;
5
6pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
7where
8    F: Fn((IdxSize, &IdxVec)) -> Option<bool> + Send + Sync,
9{
10    let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
11    ca.into_series()
12}
13
14pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
15where
16    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
17{
18    let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
19    ca.into_series()
20}
21
/// Applies the boolean reduction `f` to every group of `ca`, producing one
/// `Option<bool>` per group as a boolean [`Series`].
///
/// Empty groups (no indices, or a slice of length 0) yield `None` without
/// calling `f`.
///
/// # Safety
/// Index groups are gathered with `take_unchecked` and slice groups with
/// `_slice_from_offsets`; bounds are only `debug_assert!`ed here, so the
/// caller presumably must guarantee every group lies within `ca`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg(
    ca: &BooleanChunked,
    groups: &GroupsType,
    f: fn(&BooleanChunked) -> Option<bool>,
) -> Series {
    // With several groups, rechunk once up front rather than letting every
    // per-group gather/slice work against a fragmented layout.
    let single = if groups.len() <= 1 {
        Cow::Borrowed(ca)
    } else {
        ca.rechunk()
    };

    match groups {
        GroupsType::Idx(idx_groups) => _agg_helper_idx_bool(idx_groups, |(_, idx)| {
            debug_assert!(idx.len() <= single.len());
            if idx.is_empty() {
                return None;
            }
            let gathered = single.take_unchecked(idx);
            f(&gathered)
        }),
        GroupsType::Slice { groups: slices, .. } => {
            _agg_helper_slice_bool(slices, |[first, len]| {
                debug_assert!(len <= single.len() as IdxSize);
                if len == 0 {
                    return None;
                }
                let window = _slice_from_offsets(&single, first, len);
                f(&window)
            })
        },
    }
}
57
#[cfg(feature = "bitwise")]
impl BooleanChunked {
    /// Per-group reduction using `ChunkBitwiseReduce::and_reduce`; empty
    /// groups yield null.
    ///
    /// # Safety
    /// `groups` is forwarded to `bitwise_agg`, which gathers without bounds
    /// checks in release builds; the groups presumably must be in bounds.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        let reduce: fn(&BooleanChunked) -> Option<bool> = ChunkBitwiseReduce::and_reduce;
        bitwise_agg(self, groups, reduce)
    }

    /// Per-group reduction using `ChunkBitwiseReduce::or_reduce`; empty
    /// groups yield null.
    ///
    /// # Safety
    /// Same requirements as [`Self::agg_and`].
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        let reduce: fn(&BooleanChunked) -> Option<bool> = ChunkBitwiseReduce::or_reduce;
        bitwise_agg(self, groups, reduce)
    }

    /// Per-group reduction using `ChunkBitwiseReduce::xor_reduce`; empty
    /// groups yield null.
    ///
    /// # Safety
    /// Same requirements as [`Self::agg_and`].
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        let reduce: fn(&BooleanChunked) -> Option<bool> = ChunkBitwiseReduce::xor_reduce;
        bitwise_agg(self, groups, reduce)
    }
}
72
73impl BooleanChunked {
74    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
75        // faster paths
76        match (self.is_sorted_flag(), self.null_count()) {
77            (IsSorted::Ascending, 0) => {
78                return self.clone().into_series().agg_first(groups);
79            },
80            (IsSorted::Descending, 0) => {
81                return self.clone().into_series().agg_last(groups);
82            },
83            _ => {},
84        }
85        let ca_self = self.rechunk();
86        let arr = ca_self.downcast_iter().next().unwrap();
87        let no_nulls = arr.null_count() == 0;
88        match groups {
89            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
90                debug_assert!(idx.len() <= self.len());
91                if idx.is_empty() {
92                    None
93                } else if idx.len() == 1 {
94                    arr.get(first as usize)
95                } else if no_nulls {
96                    take_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
97                } else {
98                    take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
99                }
100            }),
101            GroupsType::Slice {
102                groups: groups_slice,
103                ..
104            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
105                debug_assert!(len <= self.len() as IdxSize);
106                match len {
107                    0 => None,
108                    1 => self.get(first as usize),
109                    _ => {
110                        let arr_group = _slice_from_offsets(self, first, len);
111                        arr_group.min()
112                    },
113                }
114            }),
115        }
116    }
117    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
118        // faster paths
119        match (self.is_sorted_flag(), self.null_count()) {
120            (IsSorted::Ascending, 0) => {
121                return self.clone().into_series().agg_last(groups);
122            },
123            (IsSorted::Descending, 0) => {
124                return self.clone().into_series().agg_first(groups);
125            },
126            _ => {},
127        }
128
129        let ca_self = self.rechunk();
130        let arr = ca_self.downcast_iter().next().unwrap();
131        let no_nulls = arr.null_count() == 0;
132        match groups {
133            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
134                debug_assert!(idx.len() <= self.len());
135                if idx.is_empty() {
136                    None
137                } else if idx.len() == 1 {
138                    self.get(first as usize)
139                } else if no_nulls {
140                    take_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
141                } else {
142                    take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
143                }
144            }),
145            GroupsType::Slice {
146                groups: groups_slice,
147                ..
148            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
149                debug_assert!(len <= self.len() as IdxSize);
150                match len {
151                    0 => None,
152                    1 => self.get(first as usize),
153                    _ => {
154                        let arr_group = _slice_from_offsets(self, first, len);
155                        arr_group.max()
156                    },
157                }
158            }),
159        }
160    }
161    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
162        self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
163            .unwrap()
164            .agg_sum(groups)
165    }
166}