Skip to main content

polars_utils/
calc_morsel_split.rs

1use std::num::NonZeroU64;
2
/// Iterator yielding the sizes of a fixed number of parts.
///
/// Each yielded size is `base_part_size`, plus one for parts that are yielded
/// while the number of parts still left (after taking the current one) is at
/// least `remainder_cutoff`. This is how a non-divisible total is spread over
/// the leading parts.
///
/// The `Default` value has zero remaining parts and therefore yields nothing.
#[derive(Default, Debug, Clone)]
pub struct PartSizesIter {
    /// Size of every part before any remainder is distributed.
    base_part_size: u64,
    /// Number of parts that have not been yielded yet.
    remaining_parts: usize,
    /// A part receives one extra unit when `remaining_parts` (already
    /// decremented for that part) is `>=` this value.
    remainder_cutoff: usize,
}

impl PartSizesIter {
    /// Splits `total_size` into `n_parts` parts whose sizes differ by at most one.
    ///
    /// The first `total_size % n_parts` yielded parts are one unit larger than
    /// the rest, so the yielded sizes sum to `total_size`. When `n_parts` is
    /// zero the iterator is empty, regardless of `total_size`.
    pub fn new_from_total_size(total_size: u64, n_parts: usize) -> Self {
        if n_parts == 0 {
            return Self::default();
        }

        let n_parts_u64 = n_parts as u64;
        let base_part_size = total_size / n_parts_u64;
        let remainder = total_size % n_parts_u64;

        // `remainder < n_parts`, so the difference lies in `1..=n_parts`.
        let remainder_cutoff = usize::try_from(n_parts_u64 - remainder)
            .expect("n_parts - remainder is at most n_parts, which is a usize");

        Self {
            base_part_size,
            remaining_parts: n_parts,
            remainder_cutoff,
        }
    }

    /// Yields `n_parts` parts that all have exactly `part_size`.
    pub fn new_from_part_size(part_size: u64, n_parts: usize) -> Self {
        Self {
            base_part_size: part_size,
            remaining_parts: n_parts,
            // After the decrement in `next`, `remaining_parts` is at most
            // `usize::MAX - 1`, so no part ever gets the extra unit.
            remainder_cutoff: usize::MAX,
        }
    }

    /// Returns the part size before any remainder is added.
    pub fn base_part_size(&self) -> u64 {
        self.base_part_size
    }
}

impl Iterator for PartSizesIter {
    type Item = u64;

    /// Yields the size of the next part, or `None` once all parts are consumed.
    fn next(&mut self) -> Option<Self::Item> {
        // `checked_sub` fails at zero; `?` then returns `None` without
        // modifying the counter, so exhaustion is sticky.
        self.remaining_parts = self.remaining_parts.checked_sub(1)?;
        let extra = u64::from(self.remaining_parts >= self.remainder_cutoff);
        Some(self.base_part_size + extra)
    }

    /// Exact bounds: the number of parts left to yield.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = ExactSizeIterator::len(self);
        (n, Some(n))
    }
}

impl ExactSizeIterator for PartSizesIter {
    /// Number of parts that have not been yielded yet.
    fn len(&self) -> usize {
        self.remaining_parts
    }
}
59
/// Number of parts to split `size` into so that the average part size is as
/// close as possible to `target_part_size`.
///
/// Returns `0` for a `size` of zero and `1` for any other `size` that does not
/// exceed the target. Otherwise the two candidates `size / target_part_size`
/// and that value plus one are compared by how far their floored average part
/// size (`size / n`) lies from the target; on a tie the smaller count wins.
pub fn calc_n_parts(size: u64, target_part_size: NonZeroU64) -> u64 {
    let target = target_part_size.get();

    if size <= target {
        return u64::from(size != 0);
    }

    // `size > target >= 1`, so `lower >= 1` and the divisions below never
    // divide by zero.
    let lower = size / target;

    // `saturating_add` collapses the range to the single candidate `lower`
    // when `lower == u64::MAX`.
    (lower..=lower.saturating_add(1))
        .min_by_key(|candidate| (size / *candidate).abs_diff(target))
        .unwrap()
}
72
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;

    use crate::calc_morsel_split::{PartSizesIter, calc_n_parts};
    use crate::itertools::Itertools;

    /// Records every `size` in `0..1000` at which the part count for a target
    /// of 100 differs from the count at the previous size.
    #[test]
    fn test_calc_n_parts() {
        let mut prev_n_parts: u64 = 0;

        let boundaries = (0..1000u64)
            .filter(|i| {
                let n_parts = calc_n_parts(*i, const { NonZeroU64::new(100).unwrap() });
                let changed = n_parts != prev_n_parts;
                prev_n_parts = n_parts;
                changed
            })
            .collect::<Vec<_>>();

        assert_eq!(boundaries, [1, 134, 242, 345, 448, 550, 651, 752, 855, 954]);
    }

    /// Checks the yielded part sizes for empty, tiny and uneven splits.
    #[test]
    fn test_part_sizes_iter() {
        // Zero parts yields nothing, whatever the total is.
        assert_eq!(
            PartSizesIter::new_from_total_size(0, 0).collect_vec(),
            &[] as &[u64]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(1, 0).collect_vec(),
            &[] as &[u64]
        );
        assert_eq!(PartSizesIter::new_from_total_size(0, 1).collect_vec(), &[0]);
        assert_eq!(PartSizesIter::new_from_total_size(1, 1).collect_vec(), &[1]);
        assert_eq!(
            PartSizesIter::new_from_total_size(1, 2).collect_vec(),
            &[1, 0]
        );
        assert_eq!(PartSizesIter::new_from_total_size(2, 1).collect_vec(), &[2]);

        // Odd totals put the extra unit on the first part.
        assert_eq!(
            PartSizesIter::new_from_total_size(100, 2).collect_vec(),
            &[50, 50]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(101, 2).collect_vec(),
            &[51, 50]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(102, 2).collect_vec(),
            &[51, 51]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(103, 2).collect_vec(),
            &[52, 51]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(104, 2).collect_vec(),
            &[52, 52]
        );
    }
}