polars_utils/
order_statistic_tree.rs

//! This module implements an order-statistic multiset backed by a
//! weight-balanced tree (WBT).
//! The balancing scheme follows the weight-balanced trees described in these papers:
//!
//!   * <https://doi.org/10.1017/S0956796811000104>
//!   * <https://doi.org/10.1137/1.9781611976007.13>
//!
//! Each tree node holds a `UnitVec` of values, so that multiple values with
//! the same key are stored together in a single node.
10
11use std::cmp::Ordering;
12use std::fmt::Debug;
13use std::ops::RangeInclusive;
14
15use slotmap::{Key as SlotMapKey, SlotMap, new_key_type};
16
17use crate::UnitVec;
18
19const DELTA: usize = 3;
20const GAMMA: usize = 2;
21
22type CompareFn<T> = fn(&T, &T) -> Ordering;
23
24new_key_type! {
25    struct Key;
26}
27
28#[derive(Debug)]
29struct Node<T> {
30    values: UnitVec<T>,
31    left: Key,
32    right: Key,
33    weight: u32,
34    num_elems: u32,
35}
36
37#[derive(Debug)]
38pub struct OrderStatisticTree<T> {
39    nodes: SlotMap<Key, Node<T>>,
40    root: Key,
41    compare: CompareFn<T>,
42}
43
44impl<T> OrderStatisticTree<T> {
45    #[inline]
46    pub fn new(compare: CompareFn<T>) -> Self {
47        OrderStatisticTree {
48            nodes: SlotMap::with_key(),
49            root: Key::null(),
50            compare,
51        }
52    }
53
54    #[inline]
55    pub fn with_capacity(capacity: usize, compare: CompareFn<T>) -> Self {
56        OrderStatisticTree {
57            nodes: SlotMap::with_capacity_and_key(capacity),
58            root: Key::null(),
59            compare,
60        }
61    }
62
63    #[inline]
64    pub fn is_empty(&self) -> bool {
65        self.len() == 0
66    }
67
68    #[inline]
69    pub fn len(&self) -> usize {
70        self.num_elems(self.root)
71    }
72
73    #[inline]
74    pub fn unique_len(&self) -> usize {
75        self.tree_weight(self.root)
76    }
77
78    #[inline]
79    pub fn clear(&mut self) {
80        self.nodes.clear();
81        self.root = Key::null();
82    }
83
84    /// Returns the total number of elements in the tree rooted at `tree`.
85    fn num_elems(&self, tree: Key) -> usize {
86        if tree.is_null() {
87            return 0;
88        }
89        unsafe { self.nodes.get_unchecked(tree) }.num_elems as usize
90    }
91
92    /// Returns the number of tree nodes, which is equal to the number of unique
93    /// elements, in the tree rooted at `tree`.
94    fn tree_weight(&self, tree: Key) -> usize {
95        if tree.is_null() {
96            return 0;
97        }
98        unsafe { self.nodes.get_unchecked(tree) }.weight as usize
99    }
100
101    #[must_use]
102    fn new_tree_node(&mut self, left: Key, values: UnitVec<T>, right: Key) -> Key {
103        let weight = self.tree_weight(left) + self.tree_weight(right) + 1;
104        let num_elems = self.num_elems(left) + self.num_elems(right) + values.len();
105        let n = Node {
106            values,
107            left,
108            right,
109            weight: weight as u32,
110            num_elems: num_elems as u32,
111        };
112        self.nodes.insert(n)
113    }
114
115    #[must_use]
116    fn new_leaf(&mut self, value: T) -> Key {
117        let mut uv = UnitVec::new();
118        uv.push(value);
119        self.new_tree_node(Key::null(), uv, Key::null())
120    }
121
122    #[must_use]
123    unsafe fn drop_tree_node(&mut self, tree: Key) -> Node<T> {
124        unsafe { self.nodes.remove(tree).unwrap_unchecked() }
125    }
126
127    #[inline]
128    pub fn get(&self, idx: usize) -> Option<&T> {
129        self._get(idx, self.root)
130    }
131
132    fn _get(&self, idx: usize, tree: Key) -> Option<&T> {
133        if tree.is_null() {
134            return None;
135        }
136
137        let n = unsafe { self.nodes.get_unchecked(tree) };
138        let own_elems = self.num_elems(tree);
139        let left_elems = self.num_elems(n.left);
140        let right_elems = self.num_elems(n.right);
141
142        if idx < left_elems {
143            self._get(idx, n.left)
144        } else if idx >= own_elems - right_elems {
145            self._get(idx - (own_elems - right_elems), n.right)
146        } else {
147            n.values.get(idx - left_elems)
148        }
149    }
150
151    #[inline]
152    pub fn insert(&mut self, value: T) {
153        (self.root, _) = self._insert(value, self.root);
154    }
155
156    #[must_use]
157    fn _insert(&mut self, value: T, tree: Key) -> (Key, bool) {
158        if tree.is_null() {
159            return (self.new_leaf(value), true);
160        }
161
162        let n = unsafe { self.nodes.get_unchecked(tree) };
163        match (self.compare)(&value, &n.values[0]) {
164            Ordering::Less => {
165                let (left, node_added) = self._insert(value, n.left);
166                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
167                n.left = left;
168                n.weight += node_added as u32;
169                n.num_elems += 1;
170                (self.balance_r(tree), node_added)
171            },
172            Ordering::Equal => {
173                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
174                n.values.push(value);
175                n.num_elems += 1;
176                (tree, false)
177            },
178            Ordering::Greater => {
179                let (right, node_added) = self._insert(value, n.right);
180                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
181                n.right = right;
182                n.weight += node_added as u32;
183                n.num_elems += 1;
184                (self.balance_l(tree), node_added)
185            },
186        }
187    }
188
189    #[inline]
190    pub fn remove(&mut self, value: &T) -> Option<T> {
191        let deleted;
192        (deleted, self.root, _) = self._remove(value, self.root);
193        deleted
194    }
195
196    #[must_use]
197    fn _remove(&mut self, value: &T, tree: Key) -> (Option<T>, Key, bool) {
198        if tree.is_null() {
199            return (None, tree, false);
200        }
201
202        let n = unsafe { self.nodes.get_unchecked(tree) };
203        match (self.compare)(value, &n.values[0]) {
204            Ordering::Less => {
205                let (deleted, left, node_removed) = self._remove(value, n.left);
206                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
207                n.left = left;
208                n.weight -= node_removed as u32;
209                n.num_elems -= deleted.is_some() as u32;
210                (deleted, self.balance_l(tree), node_removed)
211            },
212            Ordering::Greater => {
213                let (deleted, right, node_removed) = self._remove(value, n.right);
214                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
215                n.right = right;
216                n.weight -= node_removed as u32;
217                n.num_elems -= deleted.is_some() as u32;
218                (deleted, self.balance_r(tree), node_removed)
219            },
220            Ordering::Equal if n.values.len() > 1 => {
221                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
222                let popped_value = unsafe { n.values.pop().unwrap_unchecked() };
223                n.num_elems -= 1;
224                (Some(popped_value), tree, false)
225            },
226            Ordering::Equal => {
227                let mut n = unsafe { self.drop_tree_node(tree) };
228                (
229                    Some(unsafe { n.values.pop().unwrap_unchecked() }),
230                    self.glue(n.left, n.right),
231                    true,
232                )
233            },
234        }
235    }
236
237    #[must_use]
238    fn glue(&mut self, left: Key, right: Key) -> Key {
239        if left.is_null() {
240            right
241        } else if right.is_null() {
242            left
243        } else if self.tree_weight(left) > self.tree_weight(right) {
244            let (deleted, left) = self.remove_max(left);
245            let tree = self.new_tree_node(left, deleted, right);
246            self.balance_r(tree)
247        } else {
248            let (deleted, right) = self.remove_min(right);
249            let tree = self.new_tree_node(left, deleted, right);
250            self.balance_l(tree)
251        }
252    }
253
254    #[must_use]
255    fn remove_min(&mut self, tree: Key) -> (UnitVec<T>, Key) {
256        debug_assert!(!tree.is_null());
257        let n = unsafe { self.nodes.get_unchecked(tree) };
258        if n.left.is_null() {
259            let n = unsafe { self.drop_tree_node(tree) };
260            return (n.values, n.right);
261        }
262        let (deleted, left) = self.remove_min(n.left);
263        let n = unsafe { self.nodes.get_unchecked_mut(tree) };
264        n.left = left;
265        n.weight -= 1;
266        n.num_elems -= deleted.len() as u32;
267        (deleted, self.balance_l(tree))
268    }
269
270    #[must_use]
271    fn remove_max(&mut self, tree: Key) -> (UnitVec<T>, Key) {
272        debug_assert!(!tree.is_null());
273        let n = unsafe { self.nodes.get_unchecked(tree) };
274        if n.right.is_null() {
275            let n = unsafe { self.drop_tree_node(tree) };
276            return (n.values, n.left);
277        }
278        let (deleted, right) = self.remove_max(n.right);
279        let n = unsafe { self.nodes.get_unchecked_mut(tree) };
280        n.right = right;
281        n.weight -= 1;
282        n.num_elems -= deleted.len() as u32;
283        (deleted, self.balance_r(tree))
284    }
285
286    #[inline]
287    pub fn contains(&self, value: &T) -> bool {
288        self._contains(value, self.root)
289    }
290
291    fn _contains(&self, value: &T, tree: Key) -> bool {
292        if tree.is_null() {
293            return false;
294        }
295        let n = unsafe { self.nodes.get_unchecked(tree) };
296        match (self.compare)(value, &n.values[0]) {
297            Ordering::Less => self._contains(value, n.left),
298            Ordering::Equal => true,
299            Ordering::Greater => self._contains(value, n.right),
300        }
301    }
302
303    #[must_use]
304    fn balance_l(&mut self, tree: Key) -> Key {
305        let n = unsafe { self.nodes.get_unchecked(tree) };
306        if self.pair_is_balanced(n.left, n.right) {
307            return tree;
308        }
309        self.rotate_l(tree)
310    }
311
312    #[must_use]
313    fn rotate_l(&mut self, tree: Key) -> Key {
314        let n = unsafe { self.nodes.get_unchecked(tree) };
315        let r = unsafe { self.nodes.get_unchecked(n.right) };
316        if self.is_single(r.left, r.right) {
317            self.single_l(tree)
318        } else {
319            self.double_l(tree)
320        }
321    }
322
323    #[must_use]
324    fn single_l(&mut self, tree: Key) -> Key {
325        let n = unsafe { self.drop_tree_node(tree) };
326        let r = unsafe { self.drop_tree_node(n.right) };
327        let new_left = self.new_tree_node(n.left, n.values, r.left);
328        self.new_tree_node(new_left, r.values, r.right)
329    }
330
331    #[must_use]
332    fn double_l(&mut self, tree: Key) -> Key {
333        let n = unsafe { self.drop_tree_node(tree) };
334        let r = unsafe { self.drop_tree_node(n.right) };
335        let rl = unsafe { self.drop_tree_node(r.left) };
336        let new_left = self.new_tree_node(n.left, n.values, rl.left);
337        let new_right = self.new_tree_node(rl.right, r.values, r.right);
338        self.new_tree_node(new_left, rl.values, new_right)
339    }
340
341    #[must_use]
342    fn balance_r(&mut self, tree: Key) -> Key {
343        let n = unsafe { self.nodes.get_unchecked(tree) };
344        if self.pair_is_balanced(n.right, n.left) {
345            return tree;
346        }
347        self.rotate_r(tree)
348    }
349
350    #[must_use]
351    fn rotate_r(&mut self, tree: Key) -> Key {
352        let n = unsafe { self.nodes.get_unchecked(tree) };
353        let l = unsafe { self.nodes.get_unchecked(n.left) };
354        if self.is_single(l.right, l.left) {
355            self.single_r(tree)
356        } else {
357            self.double_r(tree)
358        }
359    }
360
361    #[must_use]
362    fn single_r(&mut self, tree: Key) -> Key {
363        let n = unsafe { self.drop_tree_node(tree) };
364        let l = unsafe { self.drop_tree_node(n.left) };
365        let new_right = self.new_tree_node(l.right, n.values, n.right);
366        self.new_tree_node(l.left, l.values, new_right)
367    }
368
369    #[must_use]
370    fn double_r(&mut self, tree: Key) -> Key {
371        let n = unsafe { self.drop_tree_node(tree) };
372        let l = unsafe { self.drop_tree_node(n.left) };
373        let lr = unsafe { self.drop_tree_node(l.right) };
374        let new_right = self.new_tree_node(lr.right, n.values, n.right);
375        let new_left = self.new_tree_node(l.left, l.values, lr.left);
376        self.new_tree_node(new_left, lr.values, new_right)
377    }
378
379    #[doc(hidden)]
380    pub fn is_balanced(&self) -> bool {
381        self.tree_is_balanced(self.root)
382    }
383
384    fn tree_is_balanced(&self, tree: Key) -> bool {
385        if tree.is_null() {
386            return true;
387        }
388        let n = unsafe { self.nodes.get_unchecked(tree) };
389        self.pair_is_balanced(n.left, n.right)
390            && self.pair_is_balanced(n.right, n.left)
391            && self.tree_is_balanced(n.left)
392            && self.tree_is_balanced(n.right)
393    }
394
395    fn pair_is_balanced(&self, left: Key, right: Key) -> bool {
396        let a = self.tree_weight(left);
397        let b = self.tree_weight(right);
398        DELTA * (a + 1) >= (b + 1) && DELTA * (b + 1) >= (a + 1)
399    }
400
401    fn is_single(&self, left: Key, right: Key) -> bool {
402        let a = self.tree_weight(left);
403        let b = self.tree_weight(right);
404        a + 1 < GAMMA * (b + 1)
405    }
406
407    #[inline]
408    pub fn rank_range(&self, bound: &T) -> Result<RangeInclusive<usize>, usize> {
409        self._rank_range(bound, self.root)
410    }
411
412    fn _rank_range(&self, value: &T, tree: Key) -> Result<RangeInclusive<usize>, usize> {
413        if tree.is_null() {
414            return Err(0);
415        }
416        let n = unsafe { self.nodes.get_unchecked(tree) };
417        match (self.compare)(value, &n.values[0]) {
418            Ordering::Less => self._rank_range(value, n.left),
419            Ordering::Equal => {
420                let lo = self.num_elems(n.left);
421                let hi = lo + n.values.len() - 1;
422                Ok(lo..=hi)
423            },
424            Ordering::Greater => {
425                let update_rank = |r| self.num_elems(tree) - self.num_elems(n.right) + r;
426                self._rank_range(value, n.right)
427                    .map(|rank| update_rank(*rank.start())..=update_rank(*rank.end()))
428                    .map_err(update_rank)
429            },
430        }
431    }
432
433    #[inline]
434    pub fn rank_unique(&self, value: &T) -> Result<usize, usize> {
435        self._rank_unique(value, self.root)
436    }
437
438    fn _rank_unique(&self, value: &T, tree: Key) -> Result<usize, usize> {
439        if tree.is_null() {
440            return Err(0);
441        }
442        let n = unsafe { self.nodes.get_unchecked(tree) };
443        match (self.compare)(value, &n.values[0]) {
444            Ordering::Less => self._rank_unique(value, n.left),
445            Ordering::Equal => Ok(self.tree_weight(n.left)),
446            Ordering::Greater => self
447                ._rank_unique(value, n.right)
448                .map(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank)
449                .map_err(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank),
450        }
451    }
452
453    #[inline]
454    pub fn count(&self, value: &T) -> usize {
455        self._count(value, self.root)
456    }
457
458    fn _count(&self, value: &T, tree: Key) -> usize {
459        if tree.is_null() {
460            return 0;
461        }
462        let n = unsafe { self.nodes.get_unchecked(tree) };
463        match (self.compare)(value, &n.values[0]) {
464            Ordering::Less => self._count(value, n.left),
465            Ordering::Equal => n.values.len(),
466            Ordering::Greater => self._count(value, n.right),
467        }
468    }
469}
470
471impl<T> Extend<T> for OrderStatisticTree<T> {
472    fn extend<I: IntoIterator<Item = T>>(&mut self, iterable: I) {
473        let iterator = iterable.into_iter();
474        for element in iterator {
475            self.insert(element);
476        }
477    }
478}
479
480#[cfg(test)]
481mod test {
482
483    use proptest::collection::vec;
484    use proptest::prelude::*;
485    use proptest::test_runner::TestRunner;
486
487    use super::*;
488
489    #[test]
490    fn test_insert() {
491        let mut runner = TestRunner::default();
492        runner
493            .run(&vec((0i32..100, 0i32..100), 0..100), test_insert_inner)
494            .unwrap()
495    }
496
497    fn test_insert_inner(items: Vec<(i32, i32)>) -> Result<(), TestCaseError> {
498        let cmp = |a: &(i32, i32), b: &(i32, i32)| i32::cmp(&a.0, &b.0);
499        let mut ost = OrderStatisticTree::new(cmp);
500        for item in &items {
501            ost.insert(*item);
502            assert!(ost.is_balanced());
503        }
504        assert_eq!(ost.len(), items.len());
505        let mut sorted_items = items.clone();
506        sorted_items.sort();
507        let mut collected_items = Vec::new();
508        let mut i = 0;
509        while let Some(v) = ost.get(i) {
510            collected_items.push(*v);
511            i += 1;
512        }
513        collected_items.sort();
514        assert_eq!(ost.len(), items.len());
515        assert_eq!(&collected_items, &sorted_items);
516        Ok(())
517    }
518
519    #[test]
520    fn test_remove() {
521        let mut runner = TestRunner::default();
522        runner
523            .run(
524                &(vec(0i32..100, 0..100), vec(0i32..100, 0..100)),
525                test_remove_inner,
526            )
527            .unwrap();
528    }
529
530    fn test_remove_inner(input: (Vec<i32>, Vec<i32>)) -> Result<(), TestCaseError> {
531        let (mut items, to_remove) = input;
532        let mut ost = OrderStatisticTree::new(i32::cmp);
533        for item in &items {
534            ost.insert(*item);
535            assert!(ost.is_balanced());
536        }
537        items.sort();
538        for item in &to_remove {
539            let v = ost.remove(item);
540            assert!(ost.is_balanced());
541            let idx = items.binary_search(item);
542            assert_eq!(v.is_some(), idx.is_ok());
543            if let Ok(idx) = idx {
544                items.remove(idx);
545            }
546            assert_eq!(ost.len(), items.len());
547        }
548        assert_eq!(ost.len(), items.len());
549        for item in 0..100 {
550            assert_eq!(ost.contains(&item), items.contains(&item));
551        }
552        Ok(())
553    }
554
555    #[test]
556    fn test_rank() {
557        let mut runner = TestRunner::default();
558        runner
559            .run(&vec(0i32..100, 0..100), test_rank_inner)
560            .unwrap();
561    }
562
563    fn test_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
564        let mut ost = OrderStatisticTree::new(i32::cmp);
565        for item in &items {
566            ost.insert(*item);
567        }
568        items.sort();
569        for item in 0..100 {
570            let rank = ost.rank_range(&item);
571
572            let expected_rank = if items.contains(&item) {
573                let expected_rank_lower = items.iter().filter(|&x| *x < item).count();
574                let expected_rank_upper = items.iter().filter(|&x| *x <= item).count() - 1;
575                Ok(expected_rank_lower..=expected_rank_upper)
576            } else {
577                Err(items.iter().filter(|&x| *x < item).count())
578            };
579
580            assert_eq!(rank, expected_rank);
581        }
582        Ok(())
583    }
584
585    #[test]
586    fn test_unique_rank() {
587        let mut runner = TestRunner::default();
588        runner
589            .run(&vec(0i32..50, 0..100), test_unique_rank_inner)
590            .unwrap();
591    }
592
593    fn test_unique_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
594        let mut ost = OrderStatisticTree::new(i32::cmp);
595        for item in &items {
596            ost.insert(*item);
597        }
598        assert_eq!(ost.len(), items.len());
599        items.sort();
600        items.dedup();
601        assert_eq!(ost.unique_len(), items.len());
602        for item in 0..50 {
603            let unique_rank = ost.rank_unique(&item);
604            let expected_unique_rank = if items.contains(&item) {
605                Ok(items.iter().filter(|&x| *x < item).count())
606            } else {
607                Err(items.iter().filter(|&x| *x < item).count())
608            };
609            assert_eq!(unique_rank, expected_unique_rank);
610        }
611        Ok(())
612    }
613
614    #[test]
615    fn test_empty() {
616        let ost = OrderStatisticTree::<i32>::new(i32::cmp);
617        assert!(ost.is_empty());
618        assert_eq!(ost.len(), 0);
619        assert_eq!(ost.unique_len(), 0);
620        assert!(ost.is_balanced());
621        assert!(!ost.contains(&1));
622        assert_eq!(ost.rank_range(&1), Err(0));
623        assert_eq!(ost.rank_unique(&1), Err(0));
624    }
625
626    #[test]
627    fn test_clear() {
628        let mut ost = OrderStatisticTree::new(i32::cmp);
629        for item in 0..10 {
630            ost.insert(item);
631        }
632        assert_eq!(ost.len(), 10);
633        assert_eq!(ost.unique_len(), 10);
634        ost.clear();
635        assert!(ost.is_empty());
636    }
637
638    #[test]
639    fn test_extend() {
640        let mut ost = OrderStatisticTree::new(i32::cmp);
641        ost.extend(0..10);
642        assert_eq!(ost.len(), 10);
643        assert_eq!(ost.unique_len(), 10);
644        for item in 0..10 {
645            assert!(ost.contains(&item));
646        }
647    }
648
649    #[test]
650    fn test_count() {
651        let mut ost = OrderStatisticTree::new(i32::cmp);
652        for item in &[1, 2, 2, 3, 3, 3] {
653            ost.insert(*item);
654        }
655        assert_eq!(ost.count(&1), 1);
656        assert_eq!(ost.count(&2), 2);
657        assert_eq!(ost.count(&3), 3);
658        assert_eq!(ost.count(&4), 0);
659    }
660
661    #[test]
662    fn test_get() {
663        let mut ost = OrderStatisticTree::new(i32::cmp);
664        let mut items = [3, 1, 4, 1, 5, 9, 2, 6, 5];
665        for item in items {
666            ost.insert(item);
667        }
668        items.sort();
669        for (i, item) in items.iter().enumerate() {
670            assert_eq!(ost.get(i), Some(item));
671        }
672        assert_eq!(ost.get(items.len()), None);
673    }
674}