polars_core/frame/group_by/aggregations/agg_list.rs

1use arrow::offset::Offsets;
2
3use super::*;
4use crate::chunked_array::builder::ListNullChunkedBuilder;
5use crate::series::implementations::null::NullChunked;
6
7pub trait AggList {
8    /// # Safety
9    ///
10    /// groups should be in bounds
11    unsafe fn agg_list(&self, _groups: &GroupsType) -> Series;
12}
13
14impl<T: PolarsNumericType> AggList for ChunkedArray<T> {
15    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
16        let ca = self.rechunk();
17
18        match groups {
19            GroupsType::Idx(groups) => {
20                let mut can_fast_explode = true;
21
22                let arr = ca.downcast_iter().next().unwrap();
23                let values = arr.values();
24
25                let mut offsets = Vec::<i64>::with_capacity(groups.len() + 1);
26                let mut length_so_far = 0i64;
27                offsets.push(length_so_far);
28
29                let mut list_values = Vec::<T::Native>::with_capacity(self.len());
30                groups.iter().for_each(|(_, idx)| {
31                    let idx_len = idx.len();
32                    if idx_len == 0 {
33                        can_fast_explode = false;
34                    }
35
36                    length_so_far += idx_len as i64;
37                    // SAFETY:
38                    // group tuples are in bounds
39                    {
40                        list_values.extend(idx.iter().map(|idx| {
41                            debug_assert!((*idx as usize) < values.len());
42                            *values.get_unchecked(*idx as usize)
43                        }));
44                        // SAFETY:
45                        // we know that offsets has allocated enough slots
46                        offsets.push_unchecked(length_so_far);
47                    }
48                });
49
50                let validity = if arr.null_count() > 0 {
51                    let old_validity = arr.validity().unwrap();
52                    let mut validity = MutableBitmap::from_len_set(list_values.len());
53
54                    let mut count = 0;
55                    groups.iter().for_each(|(_, idx)| {
56                        for i in idx.as_slice() {
57                            if !old_validity.get_bit_unchecked(*i as usize) {
58                                validity.set_unchecked(count, false);
59                            }
60                            count += 1;
61                        }
62                    });
63                    Some(validity.into())
64                } else {
65                    None
66                };
67
68                let array = PrimitiveArray::new(
69                    T::get_static_dtype().to_arrow(CompatLevel::newest()),
70                    list_values.into(),
71                    validity,
72                );
73                let dtype = ListArray::<i64>::default_datatype(
74                    T::get_static_dtype().to_arrow(CompatLevel::newest()),
75                );
76                // SAFETY:
77                // offsets are monotonically increasing
78                let arr = ListArray::<i64>::new(
79                    dtype,
80                    Offsets::new_unchecked(offsets).into(),
81                    Box::new(array),
82                    None,
83                );
84
85                let mut ca = ListChunked::with_chunk(self.name().clone(), arr);
86                if can_fast_explode {
87                    ca.set_fast_explode()
88                }
89                ca.into()
90            },
91            GroupsType::Slice { groups, .. } => {
92                let mut can_fast_explode = true;
93                let arr = ca.downcast_iter().next().unwrap();
94                let values = arr.values();
95
96                let mut offsets = Vec::<i64>::with_capacity(groups.len() + 1);
97                let mut length_so_far = 0i64;
98                offsets.push(length_so_far);
99
100                let mut list_values = Vec::<T::Native>::with_capacity(self.len());
101                groups.iter().for_each(|&[first, len]| {
102                    if len == 0 {
103                        can_fast_explode = false;
104                    }
105
106                    length_so_far += len as i64;
107                    list_values.extend_from_slice(&values[first as usize..(first + len) as usize]);
108                    {
109                        // SAFETY:
110                        // we know that offsets has allocated enough slots
111                        offsets.push_unchecked(length_so_far);
112                    }
113                });
114
115                let validity = if arr.null_count() > 0 {
116                    let old_validity = arr.validity().unwrap();
117                    let mut validity = MutableBitmap::from_len_set(list_values.len());
118
119                    let mut count = 0;
120                    groups.iter().for_each(|[first, len]| {
121                        for i in *first..(*first + *len) {
122                            if !old_validity.get_bit_unchecked(i as usize) {
123                                validity.set_unchecked(count, false)
124                            }
125                            count += 1;
126                        }
127                    });
128                    Some(validity.into())
129                } else {
130                    None
131                };
132
133                let array = PrimitiveArray::new(
134                    T::get_static_dtype().to_arrow(CompatLevel::newest()),
135                    list_values.into(),
136                    validity,
137                );
138                let dtype = ListArray::<i64>::default_datatype(
139                    T::get_static_dtype().to_arrow(CompatLevel::newest()),
140                );
141                let arr = ListArray::<i64>::new(
142                    dtype,
143                    Offsets::new_unchecked(offsets).into(),
144                    Box::new(array),
145                    None,
146                );
147                let mut ca = ListChunked::with_chunk(self.name().clone(), arr);
148                if can_fast_explode {
149                    ca.set_fast_explode()
150                }
151                ca.into()
152            },
153        }
154    }
155}
156
157impl AggList for NullChunked {
158    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
159        match groups {
160            GroupsType::Idx(groups) => {
161                let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len());
162                for idx in groups.all().iter() {
163                    builder.append_with_len(idx.len());
164                }
165                builder.finish().into_series()
166            },
167            GroupsType::Slice { groups, .. } => {
168                let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len());
169                for [_, len] in groups {
170                    builder.append_with_len(*len as usize);
171                }
172                builder.finish().into_series()
173            },
174        }
175    }
176}
177
178impl AggList for BooleanChunked {
179    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
180        agg_list_by_gather_and_offsets(self, groups)
181    }
182}
183
184impl AggList for StringChunked {
185    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
186        agg_list_by_gather_and_offsets(self, groups)
187    }
188}
189
190impl AggList for BinaryChunked {
191    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
192        agg_list_by_gather_and_offsets(self, groups)
193    }
194}
195
196impl AggList for ListChunked {
197    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
198        agg_list_by_gather_and_offsets(self, groups)
199    }
200}
201
202#[cfg(feature = "dtype-array")]
203impl AggList for ArrayChunked {
204    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
205        agg_list_by_gather_and_offsets(self, groups)
206    }
207}
208
#[cfg(feature = "object")]
impl<T: PolarsObject> AggList for ObjectChunked<T> {
    /// Copies the objects of every group, in group order, into one extension array
    /// and wraps it in a list array whose `i64` offsets delimit the groups.
    ///
    /// # Safety
    ///
    /// All group indices / slice windows must be in bounds of `self`.
    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
        let mut can_fast_explode = true;
        let mut offsets = Vec::<i64>::with_capacity(groups.len() + 1);
        let mut length_so_far = 0i64;
        offsets.push(length_so_far);

        // The flattened iterator below is handed to `create_extension` as trusted-length,
        // so the length we promise must be exact. That length is the sum of the group
        // lengths, which differs from `self.len()` whenever groups overlap (more values)
        // or leave rows out (fewer values).
        let total_len: usize = groups
            .iter()
            .map(|indicator| match indicator {
                GroupsIndicator::Idx((_first, idx)) => idx.len(),
                GroupsIndicator::Slice([_first, len]) => len as usize,
            })
            .sum();

        let iter = groups
            .iter()
            .flat_map(|indicator| {
                let (group_vals, len) = match indicator {
                    GroupsIndicator::Idx((_first, idx)) => {
                        // SAFETY: the caller guarantees group indices are in bounds.
                        let group_vals = self.take_unchecked(idx);

                        (group_vals, idx.len() as IdxSize)
                    },
                    GroupsIndicator::Slice([first, len]) => {
                        let group_vals = _slice_from_offsets(self, first, len);

                        (group_vals, len)
                    },
                };

                if len == 0 {
                    can_fast_explode = false;
                }
                length_so_far += len as i64;
                // SAFETY: `offsets` was allocated with one slot per group plus the
                // leading 0, and this closure runs once per group.
                offsets.push_unchecked(length_so_far);

                let arr = group_vals.downcast_iter().next().unwrap().clone();
                arr.into_iter_cloned()
            })
            .trust_my_length(total_len);

        let mut pe = create_extension(iter);

        // SAFETY: This is safe because we just created the PolarsExtension
        // meaning that the sentinel is heap allocated and the dereference of
        // the pointer does not fail.
        pe.set_to_series_fn::<T>();
        let extension_array = Box::new(pe.take_and_forget()) as ArrayRef;
        let extension_dtype = extension_array.dtype();

        let dtype = ListArray::<i64>::default_datatype(extension_dtype.clone());
        // SAFETY: offsets start at 0 and only grow by group lengths, so they are
        // monotonically increasing.
        let arr = ListArray::<i64>::new(
            dtype,
            Offsets::new_unchecked(offsets).into(),
            extension_array,
            None,
        );
        let mut listarr = ListChunked::with_chunk(self.name().clone(), arr);
        if can_fast_explode {
            listarr.set_fast_explode()
        }
        listarr.into_series()
    }
}
275
#[cfg(feature = "dtype-struct")]
impl AggList for StructChunked {
    /// Gathers the struct rows of every group (the gather goes through [`Series`])
    /// and wraps the resulting single chunk in a list array using the group offsets.
    ///
    /// # Safety
    ///
    /// All group indices / slice windows must be in bounds of `self`.
    unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
        let (gather, offsets, can_fast_explode) = groups.prepare_list_agg(self.len());

        let gathered: StructChunked = match gather {
            Some(idx) => {
                let taken = self.clone().into_series().take_unchecked(&idx);
                taken.struct_().unwrap().clone()
            },
            // Without a gather the offsets are applied to `self` directly, so it only
            // needs to be brought into a single chunk.
            None => self.rechunk().into_owned(),
        };

        let values = gathered.chunks()[0].clone();
        let list_dtype = LargeListArray::default_datatype(values.dtype().clone());

        let mut out = ListChunked::with_chunk(
            self.name().clone(),
            LargeListArray::new(list_dtype, offsets, values, None),
        );
        // Declare the inner dtype as `self`'s dtype rather than one derived from Arrow.
        out.set_dtype(DataType::List(Box::new(self.dtype().clone())));
        if can_fast_explode {
            out.set_fast_explode()
        }

        out.into_series()
    }
}
304
305unsafe fn agg_list_by_gather_and_offsets<T: PolarsDataType>(
306    ca: &ChunkedArray<T>,
307    groups: &GroupsType,
308) -> Series
309where
310    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
311{
312    let (gather, offsets, can_fast_explode) = groups.prepare_list_agg(ca.len());
313
314    let gathered = if let Some(gather) = gather {
315        ca.take_unchecked(&gather)
316    } else {
317        ca.clone()
318    };
319
320    let arr = gathered.chunks()[0].clone();
321    let dtype = LargeListArray::default_datatype(arr.dtype().clone());
322
323    let mut chunk = ListChunked::with_chunk(
324        ca.name().clone(),
325        LargeListArray::new(dtype, offsets, arr, None),
326    );
327    chunk.set_dtype(DataType::List(Box::new(ca.dtype().clone())));
328    if can_fast_explode {
329        chunk.set_fast_explode()
330    }
331
332    chunk.into_series()
333}