polars_core/frame/group_by/
position.rs

1use std::mem::ManuallyDrop;
2use std::ops::{Deref, DerefMut};
3
4use arrow::offset::OffsetsBuffer;
5use polars_utils::idx_vec::IdxVec;
6use rayon::iter::plumbing::UnindexedConsumer;
7use rayon::prelude::*;
8
9use crate::POOL;
10use crate::prelude::*;
11use crate::utils::{NoNull, flatten, slice_slice};
12
13/// Indexes of the groups, the first index is stored separately.
14/// this make sorting fast.
15#[derive(Debug, Clone, PartialEq, Eq, Default)]
16pub struct GroupsIdx {
17    pub(crate) sorted: bool,
18    first: Vec<IdxSize>,
19    all: Vec<IdxVec>,
20}
21
22pub type IdxItem = (IdxSize, IdxVec);
23pub type BorrowIdxItem<'a> = (IdxSize, &'a IdxVec);
24
25impl Drop for GroupsIdx {
26    fn drop(&mut self) {
27        let v = std::mem::take(&mut self.all);
28        // ~65k took approximately 1ms on local machine, so from that point we drop on other thread
29        // to stop query from being blocked
30        #[cfg(not(target_family = "wasm"))]
31        if v.len() > 1 << 16 {
32            std::thread::spawn(move || drop(v));
33        } else {
34            drop(v);
35        }
36
37        #[cfg(target_family = "wasm")]
38        drop(v);
39    }
40}
41
42impl From<Vec<IdxItem>> for GroupsIdx {
43    fn from(v: Vec<IdxItem>) -> Self {
44        v.into_iter().collect()
45    }
46}
47
48impl From<Vec<Vec<IdxItem>>> for GroupsIdx {
49    fn from(v: Vec<Vec<IdxItem>>) -> Self {
50        // single threaded flatten: 10% faster than `iter().flatten().collect()
51        // this is the multi-threaded impl of that
52        let (cap, offsets) = flatten::cap_and_offsets(&v);
53        let mut first = Vec::with_capacity(cap);
54        let first_ptr = first.as_ptr() as usize;
55        let mut all = Vec::with_capacity(cap);
56        let all_ptr = all.as_ptr() as usize;
57
58        POOL.install(|| {
59            v.into_par_iter()
60                .zip(offsets)
61                .for_each(|(mut inner, offset)| {
62                    unsafe {
63                        let first = (first_ptr as *const IdxSize as *mut IdxSize).add(offset);
64                        let all = (all_ptr as *const IdxVec as *mut IdxVec).add(offset);
65
66                        let inner_ptr = inner.as_mut_ptr();
67                        for i in 0..inner.len() {
68                            let (first_val, vals) = std::ptr::read(inner_ptr.add(i));
69                            std::ptr::write(first.add(i), first_val);
70                            std::ptr::write(all.add(i), vals);
71                        }
72                        // set len to 0 so that the contents will not get dropped
73                        // they are moved to `first` and `all`
74                        inner.set_len(0);
75                    }
76                });
77        });
78        unsafe {
79            all.set_len(cap);
80            first.set_len(cap);
81        }
82        GroupsIdx {
83            sorted: false,
84            first,
85            all,
86        }
87    }
88}
89
90impl GroupsIdx {
91    pub fn new(first: Vec<IdxSize>, all: Vec<IdxVec>, sorted: bool) -> Self {
92        Self { sorted, first, all }
93    }
94
95    pub fn sort(&mut self) {
96        if self.sorted {
97            return;
98        }
99        let mut idx = 0;
100        let first = std::mem::take(&mut self.first);
101        // store index and values so that we can sort those
102        let mut idx_vals = first
103            .into_iter()
104            .map(|v| {
105                let out = [idx, v];
106                idx += 1;
107                out
108            })
109            .collect_trusted::<Vec<_>>();
110        idx_vals.sort_unstable_by_key(|v| v[1]);
111
112        let take_first = || idx_vals.iter().map(|v| v[1]).collect_trusted::<Vec<_>>();
113        let take_all = || {
114            idx_vals
115                .iter()
116                .map(|v| unsafe {
117                    let idx = v[0] as usize;
118                    std::mem::take(self.all.get_unchecked_mut(idx))
119                })
120                .collect_trusted::<Vec<_>>()
121        };
122        let (first, all) = POOL.install(|| rayon::join(take_first, take_all));
123        self.first = first;
124        self.all = all;
125        self.sorted = true
126    }
127    pub fn is_sorted_flag(&self) -> bool {
128        self.sorted
129    }
130
131    pub fn iter(
132        &self,
133    ) -> std::iter::Zip<std::iter::Copied<std::slice::Iter<IdxSize>>, std::slice::Iter<IdxVec>>
134    {
135        self.into_iter()
136    }
137
138    pub fn all(&self) -> &[IdxVec] {
139        &self.all
140    }
141
142    pub fn first(&self) -> &[IdxSize] {
143        &self.first
144    }
145
146    pub fn first_mut(&mut self) -> &mut Vec<IdxSize> {
147        &mut self.first
148    }
149
150    pub(crate) fn len(&self) -> usize {
151        self.first.len()
152    }
153
154    pub(crate) unsafe fn get_unchecked(&self, index: usize) -> BorrowIdxItem {
155        let first = *self.first.get_unchecked(index);
156        let all = self.all.get_unchecked(index);
157        (first, all)
158    }
159}
160
161impl FromIterator<IdxItem> for GroupsIdx {
162    fn from_iter<T: IntoIterator<Item = IdxItem>>(iter: T) -> Self {
163        let (first, all) = iter.into_iter().unzip();
164        GroupsIdx {
165            sorted: false,
166            first,
167            all,
168        }
169    }
170}
171
172impl<'a> IntoIterator for &'a GroupsIdx {
173    type Item = BorrowIdxItem<'a>;
174    type IntoIter = std::iter::Zip<
175        std::iter::Copied<std::slice::Iter<'a, IdxSize>>,
176        std::slice::Iter<'a, IdxVec>,
177    >;
178
179    fn into_iter(self) -> Self::IntoIter {
180        self.first.iter().copied().zip(self.all.iter())
181    }
182}
183
184impl IntoIterator for GroupsIdx {
185    type Item = IdxItem;
186    type IntoIter = std::iter::Zip<std::vec::IntoIter<IdxSize>, std::vec::IntoIter<IdxVec>>;
187
188    fn into_iter(mut self) -> Self::IntoIter {
189        let first = std::mem::take(&mut self.first);
190        let all = std::mem::take(&mut self.all);
191        first.into_iter().zip(all)
192    }
193}
194
195impl FromParallelIterator<IdxItem> for GroupsIdx {
196    fn from_par_iter<I>(par_iter: I) -> Self
197    where
198        I: IntoParallelIterator<Item = IdxItem>,
199    {
200        let (first, all) = par_iter.into_par_iter().unzip();
201        GroupsIdx {
202            sorted: false,
203            first,
204            all,
205        }
206    }
207}
208
209impl<'a> IntoParallelIterator for &'a GroupsIdx {
210    type Iter = rayon::iter::Zip<
211        rayon::iter::Copied<rayon::slice::Iter<'a, IdxSize>>,
212        rayon::slice::Iter<'a, IdxVec>,
213    >;
214    type Item = BorrowIdxItem<'a>;
215
216    fn into_par_iter(self) -> Self::Iter {
217        self.first.par_iter().copied().zip(self.all.par_iter())
218    }
219}
220
221impl IntoParallelIterator for GroupsIdx {
222    type Iter = rayon::iter::Zip<rayon::vec::IntoIter<IdxSize>, rayon::vec::IntoIter<IdxVec>>;
223    type Item = IdxItem;
224
225    fn into_par_iter(mut self) -> Self::Iter {
226        let first = std::mem::take(&mut self.first);
227        let all = std::mem::take(&mut self.all);
228        first.into_par_iter().zip(all.into_par_iter())
229    }
230}
231
232/// Every group is indicated by an array where the
233///  - first value is an index to the start of the group
234///  - second value is the length of the group
235///
236/// Only used when group values are stored together
237///
238/// This type should have the invariant that it is always sorted in ascending order.
239pub type GroupsSlice = Vec<[IdxSize; 2]>;
240
241#[derive(Debug, Clone, PartialEq, Eq)]
242pub enum GroupsType {
243    Idx(GroupsIdx),
244    /// Slice is always sorted in ascending order.
245    Slice {
246        // the groups slices
247        groups: GroupsSlice,
248        // indicates if we do a rolling group_by
249        rolling: bool,
250    },
251}
252
253impl Default for GroupsType {
254    fn default() -> Self {
255        GroupsType::Idx(GroupsIdx::default())
256    }
257}
258
259impl GroupsType {
260    pub fn into_idx(self) -> GroupsIdx {
261        match self {
262            GroupsType::Idx(groups) => groups,
263            GroupsType::Slice { groups, .. } => {
264                polars_warn!(
265                    "Had to reallocate groups, missed an optimization opportunity. Please open an issue."
266                );
267                groups
268                    .iter()
269                    .map(|&[first, len]| (first, (first..first + len).collect::<IdxVec>()))
270                    .collect()
271            },
272        }
273    }
274
275    pub(crate) fn prepare_list_agg(
276        &self,
277        total_len: usize,
278    ) -> (Option<IdxCa>, OffsetsBuffer<i64>, bool) {
279        let mut can_fast_explode = true;
280        match self {
281            GroupsType::Idx(groups) => {
282                let mut list_offset = Vec::with_capacity(self.len() + 1);
283                let mut gather_offsets = Vec::with_capacity(total_len);
284
285                let mut len_so_far = 0i64;
286                list_offset.push(len_so_far);
287
288                for idx in groups {
289                    let idx = idx.1;
290                    gather_offsets.extend_from_slice(idx);
291                    len_so_far += idx.len() as i64;
292                    list_offset.push(len_so_far);
293                    can_fast_explode &= !idx.is_empty();
294                }
295                unsafe {
296                    (
297                        Some(IdxCa::from_vec(PlSmallStr::EMPTY, gather_offsets)),
298                        OffsetsBuffer::new_unchecked(list_offset.into()),
299                        can_fast_explode,
300                    )
301                }
302            },
303            GroupsType::Slice { groups, .. } => {
304                let mut list_offset = Vec::with_capacity(self.len() + 1);
305                let mut gather_offsets = Vec::with_capacity(total_len);
306                let mut len_so_far = 0i64;
307                list_offset.push(len_so_far);
308
309                for g in groups {
310                    let len = g[1];
311                    let offset = g[0];
312                    gather_offsets.extend(offset..offset + len);
313
314                    len_so_far += len as i64;
315                    list_offset.push(len_so_far);
316                    can_fast_explode &= len > 0;
317                }
318
319                unsafe {
320                    (
321                        Some(IdxCa::from_vec(PlSmallStr::EMPTY, gather_offsets)),
322                        OffsetsBuffer::new_unchecked(list_offset.into()),
323                        can_fast_explode,
324                    )
325                }
326            },
327        }
328    }
329
330    pub fn iter(&self) -> GroupsTypeIter {
331        GroupsTypeIter::new(self)
332    }
333
334    pub fn sort(&mut self) {
335        match self {
336            GroupsType::Idx(groups) => {
337                if !groups.is_sorted_flag() {
338                    groups.sort()
339                }
340            },
341            GroupsType::Slice { .. } => {
342                // invariant of the type
343            },
344        }
345    }
346
347    pub(crate) fn is_sorted_flag(&self) -> bool {
348        match self {
349            GroupsType::Idx(groups) => groups.is_sorted_flag(),
350            GroupsType::Slice { .. } => true,
351        }
352    }
353
354    pub fn take_group_firsts(self) -> Vec<IdxSize> {
355        match self {
356            GroupsType::Idx(mut groups) => std::mem::take(&mut groups.first),
357            GroupsType::Slice { groups, .. } => {
358                groups.into_iter().map(|[first, _len]| first).collect()
359            },
360        }
361    }
362
363    /// # Safety
364    /// This will not do any bounds checks. The caller must ensure
365    /// all groups have members.
366    pub unsafe fn take_group_lasts(self) -> Vec<IdxSize> {
367        match self {
368            GroupsType::Idx(groups) => groups
369                .all
370                .iter()
371                .map(|idx| *idx.get_unchecked(idx.len() - 1))
372                .collect(),
373            GroupsType::Slice { groups, .. } => groups
374                .into_iter()
375                .map(|[first, len]| first + len - 1)
376                .collect(),
377        }
378    }
379
380    pub fn par_iter(&self) -> GroupsTypeParIter {
381        GroupsTypeParIter::new(self)
382    }
383
384    /// Get a reference to the `GroupsIdx`.
385    ///
386    /// # Panic
387    ///
388    /// panics if the groups are a slice.
389    pub fn unwrap_idx(&self) -> &GroupsIdx {
390        match self {
391            GroupsType::Idx(groups) => groups,
392            GroupsType::Slice { .. } => panic!("groups are slices not index"),
393        }
394    }
395
396    /// Get a reference to the `GroupsSlice`.
397    ///
398    /// # Panic
399    ///
400    /// panics if the groups are an idx.
401    pub fn unwrap_slice(&self) -> &GroupsSlice {
402        match self {
403            GroupsType::Slice { groups, .. } => groups,
404            GroupsType::Idx(_) => panic!("groups are index not slices"),
405        }
406    }
407
408    pub fn get(&self, index: usize) -> GroupsIndicator {
409        match self {
410            GroupsType::Idx(groups) => {
411                let first = groups.first[index];
412                let all = &groups.all[index];
413                GroupsIndicator::Idx((first, all))
414            },
415            GroupsType::Slice { groups, .. } => GroupsIndicator::Slice(groups[index]),
416        }
417    }
418
419    /// Get a mutable reference to the `GroupsIdx`.
420    ///
421    /// # Panic
422    ///
423    /// panics if the groups are a slice.
424    pub fn idx_mut(&mut self) -> &mut GroupsIdx {
425        match self {
426            GroupsType::Idx(groups) => groups,
427            GroupsType::Slice { .. } => panic!("groups are slices not index"),
428        }
429    }
430
431    pub fn len(&self) -> usize {
432        match self {
433            GroupsType::Idx(groups) => groups.len(),
434            GroupsType::Slice { groups, .. } => groups.len(),
435        }
436    }
437
438    pub fn is_empty(&self) -> bool {
439        self.len() == 0
440    }
441
442    pub fn group_count(&self) -> IdxCa {
443        match self {
444            GroupsType::Idx(groups) => {
445                let ca: NoNull<IdxCa> = groups
446                    .iter()
447                    .map(|(_first, idx)| idx.len() as IdxSize)
448                    .collect_trusted();
449                ca.into_inner()
450            },
451            GroupsType::Slice { groups, .. } => {
452                let ca: NoNull<IdxCa> = groups.iter().map(|[_first, len]| *len).collect_trusted();
453                ca.into_inner()
454            },
455        }
456    }
457    pub fn as_list_chunked(&self) -> ListChunked {
458        match self {
459            GroupsType::Idx(groups) => groups
460                .iter()
461                .map(|(_first, idx)| {
462                    let ca: NoNull<IdxCa> = idx.iter().map(|&v| v as IdxSize).collect();
463                    ca.into_inner().into_series()
464                })
465                .collect_trusted(),
466            GroupsType::Slice { groups, .. } => groups
467                .iter()
468                .map(|&[first, len]| {
469                    let ca: NoNull<IdxCa> = (first..first + len).collect_trusted();
470                    ca.into_inner().into_series()
471                })
472                .collect_trusted(),
473        }
474    }
475
476    pub fn into_sliceable(self) -> GroupPositions {
477        let len = self.len();
478        slice_groups(Arc::new(self), 0, len)
479    }
480}
481
482impl From<GroupsIdx> for GroupsType {
483    fn from(groups: GroupsIdx) -> Self {
484        GroupsType::Idx(groups)
485    }
486}
487
488pub enum GroupsIndicator<'a> {
489    Idx(BorrowIdxItem<'a>),
490    Slice([IdxSize; 2]),
491}
492
493impl GroupsIndicator<'_> {
494    pub fn len(&self) -> usize {
495        match self {
496            GroupsIndicator::Idx(g) => g.1.len(),
497            GroupsIndicator::Slice([_, len]) => *len as usize,
498        }
499    }
500    pub fn first(&self) -> IdxSize {
501        match self {
502            GroupsIndicator::Idx(g) => g.0,
503            GroupsIndicator::Slice([first, _]) => *first,
504        }
505    }
506    pub fn is_empty(&self) -> bool {
507        self.len() == 0
508    }
509}
510
511pub struct GroupsTypeIter<'a> {
512    vals: &'a GroupsType,
513    len: usize,
514    idx: usize,
515}
516
517impl<'a> GroupsTypeIter<'a> {
518    fn new(vals: &'a GroupsType) -> Self {
519        let len = vals.len();
520        let idx = 0;
521        GroupsTypeIter { vals, len, idx }
522    }
523}
524
525impl<'a> Iterator for GroupsTypeIter<'a> {
526    type Item = GroupsIndicator<'a>;
527
528    fn nth(&mut self, n: usize) -> Option<Self::Item> {
529        self.idx = self.idx.saturating_add(n);
530        self.next()
531    }
532
533    fn next(&mut self) -> Option<Self::Item> {
534        if self.idx >= self.len {
535            return None;
536        }
537
538        let out = unsafe {
539            match self.vals {
540                GroupsType::Idx(groups) => {
541                    let item = groups.get_unchecked(self.idx);
542                    Some(GroupsIndicator::Idx(item))
543                },
544                GroupsType::Slice { groups, .. } => {
545                    Some(GroupsIndicator::Slice(*groups.get_unchecked(self.idx)))
546                },
547            }
548        };
549        self.idx += 1;
550        out
551    }
552}
553
554pub struct GroupsTypeParIter<'a> {
555    vals: &'a GroupsType,
556    len: usize,
557}
558
559impl<'a> GroupsTypeParIter<'a> {
560    fn new(vals: &'a GroupsType) -> Self {
561        let len = vals.len();
562        GroupsTypeParIter { vals, len }
563    }
564}
565
566impl<'a> ParallelIterator for GroupsTypeParIter<'a> {
567    type Item = GroupsIndicator<'a>;
568
569    fn drive_unindexed<C>(self, consumer: C) -> C::Result
570    where
571        C: UnindexedConsumer<Self::Item>,
572    {
573        (0..self.len)
574            .into_par_iter()
575            .map(|i| unsafe {
576                match self.vals {
577                    GroupsType::Idx(groups) => GroupsIndicator::Idx(groups.get_unchecked(i)),
578                    GroupsType::Slice { groups, .. } => {
579                        GroupsIndicator::Slice(*groups.get_unchecked(i))
580                    },
581                }
582            })
583            .drive_unindexed(consumer)
584    }
585}
586
587#[derive(Debug)]
588pub struct GroupPositions {
589    sliced: ManuallyDrop<GroupsType>,
590    // Unsliced buffer
591    original: Arc<GroupsType>,
592    offset: i64,
593    len: usize,
594}
595
596impl Clone for GroupPositions {
597    fn clone(&self) -> Self {
598        let sliced = slice_groups_inner(&self.original, self.offset, self.len);
599
600        Self {
601            sliced,
602            original: self.original.clone(),
603            offset: self.offset,
604            len: self.len,
605        }
606    }
607}
608
609impl PartialEq for GroupPositions {
610    fn eq(&self, other: &Self) -> bool {
611        self.offset == other.offset && self.len == other.len && self.sliced == other.sliced
612    }
613}
614
615impl AsRef<GroupsType> for GroupPositions {
616    fn as_ref(&self) -> &GroupsType {
617        self.sliced.deref()
618    }
619}
620
621impl Deref for GroupPositions {
622    type Target = GroupsType;
623
624    fn deref(&self) -> &Self::Target {
625        self.sliced.deref()
626    }
627}
628
629impl Default for GroupPositions {
630    fn default() -> Self {
631        GroupsType::default().into_sliceable()
632    }
633}
634
635impl GroupPositions {
636    pub fn slice(&self, offset: i64, len: usize) -> Self {
637        let offset = self.offset + offset;
638        slice_groups(
639            self.original.clone(),
640            offset,
641            // invariant that len should be in bounds, so truncate if not
642            if len > self.len { self.len } else { len },
643        )
644    }
645
646    pub fn sort(&mut self) {
647        if !self.as_ref().is_sorted_flag() {
648            let original = Arc::make_mut(&mut self.original);
649            original.sort();
650
651            self.sliced = slice_groups_inner(original, self.offset, self.len);
652        }
653    }
654
655    pub fn unroll(mut self) -> GroupPositions {
656        match self.sliced.deref_mut() {
657            GroupsType::Idx(_) => self,
658            GroupsType::Slice { rolling: false, .. } => self,
659            GroupsType::Slice {
660                groups, rolling, ..
661            } => {
662                let mut offset = 0 as IdxSize;
663                for g in groups.iter_mut() {
664                    g[0] = offset;
665                    offset += g[1];
666                }
667                *rolling = false;
668                self
669            },
670        }
671    }
672}
673
674fn slice_groups_inner(g: &GroupsType, offset: i64, len: usize) -> ManuallyDrop<GroupsType> {
675    // SAFETY:
676    // we create new `Vec`s from the sliced groups. But we wrap them in ManuallyDrop
677    // so that we never call drop on them.
678    // These groups lifetimes are bounded to the `g`. This must remain valid
679    // for the scope of the aggregation.
680    match g {
681        GroupsType::Idx(groups) => {
682            let first = unsafe {
683                let first = slice_slice(groups.first(), offset, len);
684                let ptr = first.as_ptr() as *mut _;
685                Vec::from_raw_parts(ptr, first.len(), first.len())
686            };
687
688            let all = unsafe {
689                let all = slice_slice(groups.all(), offset, len);
690                let ptr = all.as_ptr() as *mut _;
691                Vec::from_raw_parts(ptr, all.len(), all.len())
692            };
693            ManuallyDrop::new(GroupsType::Idx(GroupsIdx::new(
694                first,
695                all,
696                groups.is_sorted_flag(),
697            )))
698        },
699        GroupsType::Slice { groups, rolling } => {
700            let groups = unsafe {
701                let groups = slice_slice(groups, offset, len);
702                let ptr = groups.as_ptr() as *mut _;
703                Vec::from_raw_parts(ptr, groups.len(), groups.len())
704            };
705
706            ManuallyDrop::new(GroupsType::Slice {
707                groups,
708                rolling: *rolling,
709            })
710        },
711    }
712}
713
714fn slice_groups(g: Arc<GroupsType>, offset: i64, len: usize) -> GroupPositions {
715    let sliced = slice_groups_inner(g.as_ref(), offset, len);
716
717    GroupPositions {
718        sliced,
719        original: g,
720        offset,
721        len,
722    }
723}