polars_utils/
calc_morsel_split.rs1use std::num::NonZeroU64;
2
/// Iterator over the sizes of a sequence of parts.
///
/// Built either by splitting a total size as evenly as possible
/// ([`PartSizesIter::new_from_total_size`]) or by repeating one fixed size
/// ([`PartSizesIter::new_from_part_size`]). When a total does not divide evenly,
/// the parts that are one larger than the base size are yielded first.
#[derive(Default, Debug, Clone)]
pub struct PartSizesIter {
    /// Size of every part before the optional `+ 1` given to the leading parts.
    base_part_size: u64,
    /// Number of parts not yet yielded.
    remaining_parts: usize,
    /// A part gets `base_part_size + 1` when `remaining_parts`, already
    /// decremented for that part, is `>= remainder_cutoff`.
    remainder_cutoff: usize,
}

impl PartSizesIter {
    /// Splits `total_size` into `n_parts` parts that differ in size by at most one
    /// and sum to exactly `total_size`.
    ///
    /// The first `total_size % n_parts` parts are one larger than the rest.
    /// With `n_parts == 0` the iterator yields nothing.
    pub fn new_from_total_size(total_size: u64, n_parts: usize) -> Self {
        if n_parts == 0 {
            return Default::default();
        }

        let base_part_size = total_size / n_parts as u64;
        // `remainder < n_parts`, so the cast to `usize` is lossless and the
        // subtraction below cannot underflow.
        let remainder = (total_size % n_parts as u64) as usize;
        let remainder_cutoff = n_parts - remainder;

        Self {
            base_part_size,
            remaining_parts: n_parts,
            remainder_cutoff,
        }
    }

    /// Yields `n_parts` parts, each exactly `part_size` long.
    pub fn new_from_part_size(part_size: u64, n_parts: usize) -> Self {
        Self {
            base_part_size: part_size,
            remaining_parts: n_parts,
            // After decrementing, `remaining_parts` is always `< usize::MAX`, so no
            // part ever receives the `+ 1`.
            remainder_cutoff: usize::MAX,
        }
    }

    /// The base size of each part. Parts produced by
    /// [`PartSizesIter::new_from_total_size`] may be one larger than this.
    pub fn base_part_size(&self) -> u64 {
        self.base_part_size
    }
}

impl Iterator for PartSizesIter {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        self.remaining_parts = self.remaining_parts.checked_sub(1)?;
        // The leading `n_parts - remainder_cutoff` parts absorb the remainder.
        let extra = u64::from(self.remaining_parts >= self.remainder_cutoff);
        Some(self.base_part_size + extra)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = ExactSizeIterator::len(self);
        (n, Some(n))
    }
}

impl ExactSizeIterator for PartSizesIter {
    fn len(&self) -> usize {
        self.remaining_parts
    }
}

/// Once exhausted, `remaining_parts` stays at zero, so `next` keeps returning `None`.
impl std::iter::FusedIterator for PartSizesIter {}
59
/// Chooses how many parts to split `size` into so that the resulting part size
/// (`size / n_parts`, rounded down) lands as close as possible to `target_part_size`.
///
/// Returns `0` when `size` is zero and `1` whenever `size` does not exceed the
/// target. Otherwise the two candidates `size / target` and one more than that
/// are compared; on a tie the smaller count is kept.
pub fn calc_n_parts(size: u64, target_part_size: NonZeroU64) -> u64 {
    let target = target_part_size.get();

    if size == 0 {
        return 0;
    }
    if size <= target {
        return 1;
    }

    // Here `size > target >= 1`, so `fewer >= 1` and neither division below can
    // divide by zero. `more` saturates, which collapses to a single candidate.
    let fewer = size / target;
    let more = fewer.saturating_add(1);

    let distance = |parts: u64| (size / parts).abs_diff(target);
    if distance(more) < distance(fewer) {
        more
    } else {
        fewer
    }
}
72
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;

    use crate::calc_morsel_split::{PartSizesIter, calc_n_parts};
    use crate::itertools::Itertools;

    #[test]
    fn test_calc_n_parts() {
        // Record every size at which the chosen number of parts changes.
        let mut prev_n_parts: u64 = 0;

        let boundaries = (0..1000u64)
            .filter(|i| {
                let n_parts = calc_n_parts(*i, const { NonZeroU64::new(100).unwrap() });
                let changed = n_parts != prev_n_parts;
                prev_n_parts = n_parts;
                changed
            })
            .collect::<Vec<_>>();

        assert_eq!(boundaries, [1, 134, 242, 345, 448, 550, 651, 752, 855, 954]);
    }

    #[test]
    fn test_calc_n_parts_extremes() {
        let one = const { NonZeroU64::new(1).unwrap() };

        assert_eq!(calc_n_parts(0, one), 0);
        // `n_parts + 1` would overflow here; it must saturate instead of panicking.
        assert_eq!(calc_n_parts(u64::MAX, one), u64::MAX);
        assert_eq!(calc_n_parts(u64::MAX, NonZeroU64::MAX), 1);
    }

    #[test]
    fn test_part_sizes_iter() {
        assert_eq!(
            PartSizesIter::new_from_total_size(0, 0).collect_vec(),
            &[] as &[u64]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(1, 0).collect_vec(),
            &[] as &[u64]
        );
        assert_eq!(PartSizesIter::new_from_total_size(0, 1).collect_vec(), &[0]);
        assert_eq!(PartSizesIter::new_from_total_size(1, 1).collect_vec(), &[1]);
        assert_eq!(
            PartSizesIter::new_from_total_size(1, 2).collect_vec(),
            &[1, 0]
        );
        assert_eq!(PartSizesIter::new_from_total_size(2, 1).collect_vec(), &[2]);

        assert_eq!(
            PartSizesIter::new_from_total_size(100, 2).collect_vec(),
            &[50, 50]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(101, 2).collect_vec(),
            &[51, 50]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(102, 2).collect_vec(),
            &[51, 51]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(103, 2).collect_vec(),
            &[52, 51]
        );
        assert_eq!(
            PartSizesIter::new_from_total_size(104, 2).collect_vec(),
            &[52, 52]
        );
    }

    #[test]
    fn test_part_sizes_iter_sums_to_total() {
        for total in 0..50u64 {
            for n_parts in 1..10usize {
                let sizes = PartSizesIter::new_from_total_size(total, n_parts).collect_vec();

                assert_eq!(sizes.len(), n_parts);
                assert_eq!(sizes.iter().sum::<u64>(), total);
                // Larger parts come first, and sizes differ by at most one.
                assert!(sizes.windows(2).all(|w| w[0] >= w[1]));
                assert!(sizes[0] - sizes[n_parts - 1] <= 1);
            }
        }
    }

    #[test]
    fn test_part_sizes_iter_from_part_size() {
        assert_eq!(
            PartSizesIter::new_from_part_size(7, 0).collect_vec(),
            &[] as &[u64]
        );
        assert_eq!(
            PartSizesIter::new_from_part_size(7, 3).collect_vec(),
            &[7, 7, 7]
        );
        assert_eq!(PartSizesIter::new_from_part_size(7, 3).base_part_size(), 7);
    }

    #[test]
    fn test_part_sizes_iter_len() {
        let mut iter = PartSizesIter::new_from_total_size(10, 3);

        assert_eq!(iter.base_part_size(), 3);
        assert_eq!(iter.len(), 3);
        assert_eq!(iter.size_hint(), (3, Some(3)));
        assert_eq!(iter.next(), Some(4));
        assert_eq!(iter.len(), 2);
        assert_eq!(iter.by_ref().collect_vec(), &[3, 3]);
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }
}