1use std::cmp::Ordering;
12use std::fmt::Debug;
13use std::ops::RangeInclusive;
14
15use slotmap::{Key as SlotMapKey, SlotMap, new_key_type};
16
17use crate::UnitVec;
18
/// Weight-balance factor: two sibling subtrees with node counts `a` and `b` are
/// considered balanced while `DELTA * (a + 1) >= b + 1` and
/// `DELTA * (b + 1) >= a + 1` both hold (see `pair_is_balanced`).
const DELTA: usize = 3;
/// Rotation selector: when rebalancing, a single rotation is used if the heavy
/// child's inner subtree (node count `a`) and outer subtree (node count `b`)
/// satisfy `a + 1 < GAMMA * (b + 1)`; otherwise a double rotation is used
/// (see `is_single`, `rotate_l` and `rotate_r`).
const GAMMA: usize = 2;

/// Total order used to place elements; elements comparing `Equal` share one node.
type CompareFn<T> = fn(&T, &T) -> Ordering;
23
// Slot-map key identifying a tree node. `Key::null()` is used throughout to
// represent an empty subtree (e.g. a missing child or an empty tree's root).
new_key_type! {
    struct Key;
}
27
/// A tree node holding every stored element that compares equal to one value.
#[derive(Debug)]
struct Node<T> {
    /// Elements of this node. They compare `Equal` to each other, and `values[0]`
    /// serves as the node's search key. A node is removed from the tree instead of
    /// being left with no values, so this is never empty while linked in.
    values: UnitVec<T>,
    /// Subtree of smaller elements (`Key::null()` when empty).
    left: Key,
    /// Subtree of larger elements (`Key::null()` when empty).
    right: Key,
    /// Number of nodes (distinct values) in the subtree rooted here, this node included.
    weight: u32,
    /// Number of elements, duplicates counted, in the subtree rooted here.
    num_elems: u32,
}
36
/// A multiset ordered by a user-supplied comparison function that answers
/// positional (order-statistic) queries: element at index, rank of a value,
/// and number of duplicates.
///
/// Implemented as a weight-balanced binary search tree whose nodes live in a
/// `SlotMap`; elements that compare equal are grouped in a single node.
#[derive(Debug)]
pub struct OrderStatisticTree<T> {
    /// Storage for all nodes; nodes refer to their children by `Key`.
    nodes: SlotMap<Key, Node<T>>,
    /// Key of the root node, or `Key::null()` when the tree is empty.
    root: Key,
    /// Total order used to place elements.
    compare: CompareFn<T>,
}
43
44impl<T> OrderStatisticTree<T> {
45 #[inline]
46 pub fn new(compare: CompareFn<T>) -> Self {
47 OrderStatisticTree {
48 nodes: SlotMap::with_key(),
49 root: Key::null(),
50 compare,
51 }
52 }
53
54 #[inline]
55 pub fn with_capacity(capacity: usize, compare: CompareFn<T>) -> Self {
56 OrderStatisticTree {
57 nodes: SlotMap::with_capacity_and_key(capacity),
58 root: Key::null(),
59 compare,
60 }
61 }
62
63 #[inline]
64 pub fn is_empty(&self) -> bool {
65 self.len() == 0
66 }
67
68 #[inline]
69 pub fn len(&self) -> usize {
70 self.num_elems(self.root)
71 }
72
73 #[inline]
74 pub fn unique_len(&self) -> usize {
75 self.tree_weight(self.root)
76 }
77
78 #[inline]
79 pub fn clear(&mut self) {
80 self.nodes.clear();
81 self.root = Key::null();
82 }
83
84 fn num_elems(&self, tree: Key) -> usize {
86 if tree.is_null() {
87 return 0;
88 }
89 unsafe { self.nodes.get_unchecked(tree) }.num_elems as usize
90 }
91
92 fn tree_weight(&self, tree: Key) -> usize {
95 if tree.is_null() {
96 return 0;
97 }
98 unsafe { self.nodes.get_unchecked(tree) }.weight as usize
99 }
100
101 #[must_use]
102 fn new_tree_node(&mut self, left: Key, values: UnitVec<T>, right: Key) -> Key {
103 let weight = self.tree_weight(left) + self.tree_weight(right) + 1;
104 let num_elems = self.num_elems(left) + self.num_elems(right) + values.len();
105 let n = Node {
106 values,
107 left,
108 right,
109 weight: weight as u32,
110 num_elems: num_elems as u32,
111 };
112 self.nodes.insert(n)
113 }
114
115 #[must_use]
116 fn new_leaf(&mut self, value: T) -> Key {
117 let mut uv = UnitVec::new();
118 uv.push(value);
119 self.new_tree_node(Key::null(), uv, Key::null())
120 }
121
122 #[must_use]
123 unsafe fn drop_tree_node(&mut self, tree: Key) -> Node<T> {
124 unsafe { self.nodes.remove(tree).unwrap_unchecked() }
125 }
126
127 #[inline]
128 pub fn get(&self, idx: usize) -> Option<&T> {
129 self._get(idx, self.root)
130 }
131
132 fn _get(&self, idx: usize, tree: Key) -> Option<&T> {
133 if tree.is_null() {
134 return None;
135 }
136
137 let n = unsafe { self.nodes.get_unchecked(tree) };
138 let own_elems = self.num_elems(tree);
139 let left_elems = self.num_elems(n.left);
140 let right_elems = self.num_elems(n.right);
141
142 if idx < left_elems {
143 self._get(idx, n.left)
144 } else if idx >= own_elems - right_elems {
145 self._get(idx - (own_elems - right_elems), n.right)
146 } else {
147 n.values.get(idx - left_elems)
148 }
149 }
150
151 #[inline]
152 pub fn insert(&mut self, value: T) {
153 (self.root, _) = self._insert(value, self.root);
154 }
155
156 #[must_use]
157 fn _insert(&mut self, value: T, tree: Key) -> (Key, bool) {
158 if tree.is_null() {
159 return (self.new_leaf(value), true);
160 }
161
162 let n = unsafe { self.nodes.get_unchecked(tree) };
163 match (self.compare)(&value, &n.values[0]) {
164 Ordering::Less => {
165 let (left, node_added) = self._insert(value, n.left);
166 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
167 n.left = left;
168 n.weight += node_added as u32;
169 n.num_elems += 1;
170 (self.balance_r(tree), node_added)
171 },
172 Ordering::Equal => {
173 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
174 n.values.push(value);
175 n.num_elems += 1;
176 (tree, false)
177 },
178 Ordering::Greater => {
179 let (right, node_added) = self._insert(value, n.right);
180 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
181 n.right = right;
182 n.weight += node_added as u32;
183 n.num_elems += 1;
184 (self.balance_l(tree), node_added)
185 },
186 }
187 }
188
189 #[inline]
190 pub fn remove(&mut self, value: &T) -> Option<T> {
191 let deleted;
192 (deleted, self.root, _) = self._remove(value, self.root);
193 deleted
194 }
195
196 #[must_use]
197 fn _remove(&mut self, value: &T, tree: Key) -> (Option<T>, Key, bool) {
198 if tree.is_null() {
199 return (None, tree, false);
200 }
201
202 let n = unsafe { self.nodes.get_unchecked(tree) };
203 match (self.compare)(value, &n.values[0]) {
204 Ordering::Less => {
205 let (deleted, left, node_removed) = self._remove(value, n.left);
206 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
207 n.left = left;
208 n.weight -= node_removed as u32;
209 n.num_elems -= deleted.is_some() as u32;
210 (deleted, self.balance_l(tree), node_removed)
211 },
212 Ordering::Greater => {
213 let (deleted, right, node_removed) = self._remove(value, n.right);
214 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
215 n.right = right;
216 n.weight -= node_removed as u32;
217 n.num_elems -= deleted.is_some() as u32;
218 (deleted, self.balance_r(tree), node_removed)
219 },
220 Ordering::Equal if n.values.len() > 1 => {
221 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
222 let popped_value = unsafe { n.values.pop().unwrap_unchecked() };
223 n.num_elems -= 1;
224 (Some(popped_value), tree, false)
225 },
226 Ordering::Equal => {
227 let mut n = unsafe { self.drop_tree_node(tree) };
228 (
229 Some(unsafe { n.values.pop().unwrap_unchecked() }),
230 self.glue(n.left, n.right),
231 true,
232 )
233 },
234 }
235 }
236
237 #[must_use]
238 fn glue(&mut self, left: Key, right: Key) -> Key {
239 if left.is_null() {
240 right
241 } else if right.is_null() {
242 left
243 } else if self.tree_weight(left) > self.tree_weight(right) {
244 let (deleted, left) = self.remove_max(left);
245 let tree = self.new_tree_node(left, deleted, right);
246 self.balance_r(tree)
247 } else {
248 let (deleted, right) = self.remove_min(right);
249 let tree = self.new_tree_node(left, deleted, right);
250 self.balance_l(tree)
251 }
252 }
253
254 #[must_use]
255 fn remove_min(&mut self, tree: Key) -> (UnitVec<T>, Key) {
256 debug_assert!(!tree.is_null());
257 let n = unsafe { self.nodes.get_unchecked(tree) };
258 if n.left.is_null() {
259 let n = unsafe { self.drop_tree_node(tree) };
260 return (n.values, n.right);
261 }
262 let (deleted, left) = self.remove_min(n.left);
263 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
264 n.left = left;
265 n.weight -= 1;
266 n.num_elems -= deleted.len() as u32;
267 (deleted, self.balance_l(tree))
268 }
269
270 #[must_use]
271 fn remove_max(&mut self, tree: Key) -> (UnitVec<T>, Key) {
272 debug_assert!(!tree.is_null());
273 let n = unsafe { self.nodes.get_unchecked(tree) };
274 if n.right.is_null() {
275 let n = unsafe { self.drop_tree_node(tree) };
276 return (n.values, n.left);
277 }
278 let (deleted, right) = self.remove_max(n.right);
279 let n = unsafe { self.nodes.get_unchecked_mut(tree) };
280 n.right = right;
281 n.weight -= 1;
282 n.num_elems -= deleted.len() as u32;
283 (deleted, self.balance_r(tree))
284 }
285
286 #[inline]
287 pub fn contains(&self, value: &T) -> bool {
288 self._contains(value, self.root)
289 }
290
291 fn _contains(&self, value: &T, tree: Key) -> bool {
292 if tree.is_null() {
293 return false;
294 }
295 let n = unsafe { self.nodes.get_unchecked(tree) };
296 match (self.compare)(value, &n.values[0]) {
297 Ordering::Less => self._contains(value, n.left),
298 Ordering::Equal => true,
299 Ordering::Greater => self._contains(value, n.right),
300 }
301 }
302
303 #[must_use]
304 fn balance_l(&mut self, tree: Key) -> Key {
305 let n = unsafe { self.nodes.get_unchecked(tree) };
306 if self.pair_is_balanced(n.left, n.right) {
307 return tree;
308 }
309 self.rotate_l(tree)
310 }
311
312 #[must_use]
313 fn rotate_l(&mut self, tree: Key) -> Key {
314 let n = unsafe { self.nodes.get_unchecked(tree) };
315 let r = unsafe { self.nodes.get_unchecked(n.right) };
316 if self.is_single(r.left, r.right) {
317 self.single_l(tree)
318 } else {
319 self.double_l(tree)
320 }
321 }
322
323 #[must_use]
324 fn single_l(&mut self, tree: Key) -> Key {
325 let n = unsafe { self.drop_tree_node(tree) };
326 let r = unsafe { self.drop_tree_node(n.right) };
327 let new_left = self.new_tree_node(n.left, n.values, r.left);
328 self.new_tree_node(new_left, r.values, r.right)
329 }
330
331 #[must_use]
332 fn double_l(&mut self, tree: Key) -> Key {
333 let n = unsafe { self.drop_tree_node(tree) };
334 let r = unsafe { self.drop_tree_node(n.right) };
335 let rl = unsafe { self.drop_tree_node(r.left) };
336 let new_left = self.new_tree_node(n.left, n.values, rl.left);
337 let new_right = self.new_tree_node(rl.right, r.values, r.right);
338 self.new_tree_node(new_left, rl.values, new_right)
339 }
340
341 #[must_use]
342 fn balance_r(&mut self, tree: Key) -> Key {
343 let n = unsafe { self.nodes.get_unchecked(tree) };
344 if self.pair_is_balanced(n.right, n.left) {
345 return tree;
346 }
347 self.rotate_r(tree)
348 }
349
350 #[must_use]
351 fn rotate_r(&mut self, tree: Key) -> Key {
352 let n = unsafe { self.nodes.get_unchecked(tree) };
353 let l = unsafe { self.nodes.get_unchecked(n.left) };
354 if self.is_single(l.right, l.left) {
355 self.single_r(tree)
356 } else {
357 self.double_r(tree)
358 }
359 }
360
361 #[must_use]
362 fn single_r(&mut self, tree: Key) -> Key {
363 let n = unsafe { self.drop_tree_node(tree) };
364 let l = unsafe { self.drop_tree_node(n.left) };
365 let new_right = self.new_tree_node(l.right, n.values, n.right);
366 self.new_tree_node(l.left, l.values, new_right)
367 }
368
369 #[must_use]
370 fn double_r(&mut self, tree: Key) -> Key {
371 let n = unsafe { self.drop_tree_node(tree) };
372 let l = unsafe { self.drop_tree_node(n.left) };
373 let lr = unsafe { self.drop_tree_node(l.right) };
374 let new_right = self.new_tree_node(lr.right, n.values, n.right);
375 let new_left = self.new_tree_node(l.left, l.values, lr.left);
376 self.new_tree_node(new_left, lr.values, new_right)
377 }
378
379 #[doc(hidden)]
380 pub fn is_balanced(&self) -> bool {
381 self.tree_is_balanced(self.root)
382 }
383
384 fn tree_is_balanced(&self, tree: Key) -> bool {
385 if tree.is_null() {
386 return true;
387 }
388 let n = unsafe { self.nodes.get_unchecked(tree) };
389 self.pair_is_balanced(n.left, n.right)
390 && self.pair_is_balanced(n.right, n.left)
391 && self.tree_is_balanced(n.left)
392 && self.tree_is_balanced(n.right)
393 }
394
395 fn pair_is_balanced(&self, left: Key, right: Key) -> bool {
396 let a = self.tree_weight(left);
397 let b = self.tree_weight(right);
398 DELTA * (a + 1) >= (b + 1) && DELTA * (b + 1) >= (a + 1)
399 }
400
401 fn is_single(&self, left: Key, right: Key) -> bool {
402 let a = self.tree_weight(left);
403 let b = self.tree_weight(right);
404 a + 1 < GAMMA * (b + 1)
405 }
406
407 #[inline]
408 pub fn rank_range(&self, bound: &T) -> Result<RangeInclusive<usize>, usize> {
409 self._rank_range(bound, self.root)
410 }
411
412 fn _rank_range(&self, value: &T, tree: Key) -> Result<RangeInclusive<usize>, usize> {
413 if tree.is_null() {
414 return Err(0);
415 }
416 let n = unsafe { self.nodes.get_unchecked(tree) };
417 match (self.compare)(value, &n.values[0]) {
418 Ordering::Less => self._rank_range(value, n.left),
419 Ordering::Equal => {
420 let lo = self.num_elems(n.left);
421 let hi = lo + n.values.len() - 1;
422 Ok(lo..=hi)
423 },
424 Ordering::Greater => {
425 let update_rank = |r| self.num_elems(tree) - self.num_elems(n.right) + r;
426 self._rank_range(value, n.right)
427 .map(|rank| update_rank(*rank.start())..=update_rank(*rank.end()))
428 .map_err(update_rank)
429 },
430 }
431 }
432
433 #[inline]
434 pub fn rank_unique(&self, value: &T) -> Result<usize, usize> {
435 self._rank_unique(value, self.root)
436 }
437
438 fn _rank_unique(&self, value: &T, tree: Key) -> Result<usize, usize> {
439 if tree.is_null() {
440 return Err(0);
441 }
442 let n = unsafe { self.nodes.get_unchecked(tree) };
443 match (self.compare)(value, &n.values[0]) {
444 Ordering::Less => self._rank_unique(value, n.left),
445 Ordering::Equal => Ok(self.tree_weight(n.left)),
446 Ordering::Greater => self
447 ._rank_unique(value, n.right)
448 .map(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank)
449 .map_err(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank),
450 }
451 }
452
453 #[inline]
454 pub fn count(&self, value: &T) -> usize {
455 self._count(value, self.root)
456 }
457
458 fn _count(&self, value: &T, tree: Key) -> usize {
459 if tree.is_null() {
460 return 0;
461 }
462 let n = unsafe { self.nodes.get_unchecked(tree) };
463 match (self.compare)(value, &n.values[0]) {
464 Ordering::Less => self._count(value, n.left),
465 Ordering::Equal => n.values.len(),
466 Ordering::Greater => self._count(value, n.right),
467 }
468 }
469}
470
471impl<T> Extend<T> for OrderStatisticTree<T> {
472 fn extend<I: IntoIterator<Item = T>>(&mut self, iterable: I) {
473 let iterator = iterable.into_iter();
474 for element in iterator {
475 self.insert(element);
476 }
477 }
478}
479
#[cfg(test)]
mod test {

    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::TestRunner;

    use super::*;

    #[test]
    fn test_insert() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec((0i32..100, 0i32..100), 0..100), test_insert_inner)
            .unwrap()
    }

    /// Inserts pairs ordered by their first component only. Checks that the tree
    /// stays balanced and that `get` walks every inserted pair in key order.
    fn test_insert_inner(items: Vec<(i32, i32)>) -> Result<(), TestCaseError> {
        let cmp = |a: &(i32, i32), b: &(i32, i32)| i32::cmp(&a.0, &b.0);
        let mut ost = OrderStatisticTree::new(cmp);
        for item in &items {
            ost.insert(*item);
            prop_assert!(ost.is_balanced());
        }
        prop_assert_eq!(ost.len(), items.len());

        let mut collected_items = Vec::with_capacity(items.len());
        let mut i = 0;
        while let Some(v) = ost.get(i) {
            collected_items.push(*v);
            i += 1;
        }
        // Indexed access must follow the comparator, i.e. be sorted by the first
        // component; pairs with equal keys may appear in any relative order here.
        prop_assert!(collected_items.windows(2).all(|w| w[0].0 <= w[1].0));

        // As multisets, the tree must hold exactly the inserted pairs.
        let mut sorted_items = items;
        sorted_items.sort();
        collected_items.sort();
        prop_assert_eq!(collected_items, sorted_items);
        Ok(())
    }

    #[test]
    fn test_remove() {
        let mut runner = TestRunner::default();
        runner
            .run(
                &(vec(0i32..100, 0..100), vec(0i32..100, 0..100)),
                test_remove_inner,
            )
            .unwrap();
    }

    fn test_remove_inner(input: (Vec<i32>, Vec<i32>)) -> Result<(), TestCaseError> {
        let (mut items, to_remove) = input;
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
            prop_assert!(ost.is_balanced());
        }
        items.sort();
        for item in &to_remove {
            let removed = ost.remove(item);
            prop_assert!(ost.is_balanced());
            let idx = items.binary_search(item);
            prop_assert_eq!(removed.is_some(), idx.is_ok());
            if let Ok(idx) = idx {
                // `i32::cmp` is a total order, so the removed element is `item` itself.
                prop_assert_eq!(removed, Some(*item));
                items.remove(idx);
            }
            prop_assert_eq!(ost.len(), items.len());
        }
        prop_assert_eq!(ost.len(), items.len());
        for item in 0..100 {
            prop_assert_eq!(ost.contains(&item), items.contains(&item));
        }
        Ok(())
    }

    #[test]
    fn test_rank() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec(0i32..100, 0..100), test_rank_inner)
            .unwrap();
    }

    fn test_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
        }
        items.sort();
        for item in 0..100 {
            let rank = ost.rank_range(&item);

            let num_less = items.iter().filter(|&x| *x < item).count();
            let expected_rank = if items.contains(&item) {
                let num_less_or_equal = items.iter().filter(|&x| *x <= item).count();
                Ok(num_less..=num_less_or_equal - 1)
            } else {
                Err(num_less)
            };

            prop_assert_eq!(rank, expected_rank);
        }
        Ok(())
    }

    #[test]
    fn test_unique_rank() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec(0i32..50, 0..100), test_unique_rank_inner)
            .unwrap();
    }

    fn test_unique_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
        }
        prop_assert_eq!(ost.len(), items.len());
        items.sort();
        items.dedup();
        prop_assert_eq!(ost.unique_len(), items.len());
        for item in 0..50 {
            let unique_rank = ost.rank_unique(&item);
            let num_less = items.iter().filter(|&x| *x < item).count();
            let expected_unique_rank = if items.contains(&item) {
                Ok(num_less)
            } else {
                Err(num_less)
            };
            prop_assert_eq!(unique_rank, expected_unique_rank);
        }
        Ok(())
    }

    #[test]
    fn test_empty() {
        let ost = OrderStatisticTree::<i32>::new(i32::cmp);
        assert!(ost.is_empty());
        assert_eq!(ost.len(), 0);
        assert_eq!(ost.unique_len(), 0);
        assert!(ost.is_balanced());
        assert!(!ost.contains(&1));
        assert_eq!(ost.count(&1), 0);
        assert_eq!(ost.get(0), None);
        assert_eq!(ost.rank_range(&1), Err(0));
        assert_eq!(ost.rank_unique(&1), Err(0));
    }

    #[test]
    fn test_clear() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in 0..10 {
            ost.insert(item);
        }
        assert_eq!(ost.len(), 10);
        assert_eq!(ost.unique_len(), 10);
        ost.clear();
        assert!(ost.is_empty());
    }

    #[test]
    fn test_extend() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        ost.extend(0..10);
        assert_eq!(ost.len(), 10);
        assert_eq!(ost.unique_len(), 10);
        for item in 0..10 {
            assert!(ost.contains(&item));
        }
    }

    #[test]
    fn test_count() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &[1, 2, 2, 3, 3, 3] {
            ost.insert(*item);
        }
        assert_eq!(ost.count(&1), 1);
        assert_eq!(ost.count(&2), 2);
        assert_eq!(ost.count(&3), 3);
        assert_eq!(ost.count(&4), 0);
    }

    #[test]
    fn test_get() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        let mut items = [3, 1, 4, 1, 5, 9, 2, 6, 5];
        for item in items {
            ost.insert(item);
        }
        items.sort();
        for (i, item) in items.iter().enumerate() {
            assert_eq!(ost.get(i), Some(item));
        }
        assert_eq!(ost.get(items.len()), None);
    }
}