polars_utils/
cache.rs

1use std::borrow::Borrow;
2use std::hash::{BuildHasher, Hash};
3
4use foldhash::fast::RandomState;
5use hashbrown::HashTable;
6use hashbrown::hash_table::Entry;
7use slotmap::{Key, SlotMap, new_key_type};
8
/// A cached function that uses an `LruCache` to remember its results.
pub struct LruCachedFunc<T, R, F> {
    /// The wrapped function that computes a result from an argument.
    func: F,
    /// Results of earlier cached evaluations, keyed by the argument.
    cache: LruCache<T, R>,
}
14
15impl<T, R, F> LruCachedFunc<T, R, F>
16where
17    F: FnMut(T) -> R,
18    T: std::hash::Hash + Eq + Clone,
19    R: Copy,
20{
21    pub fn new(func: F, size: usize) -> Self {
22        Self {
23            func,
24            cache: LruCache::with_capacity(size),
25        }
26    }
27
28    pub fn eval(&mut self, x: T, use_cache: bool) -> R {
29        if use_cache {
30            *self
31                .cache
32                .get_or_insert_with(&x, |xr| (self.func)(xr.clone()))
33        } else {
34            (self.func)(x)
35        }
36    }
37}
38
// Defines `LruKey`, the slot map key type that identifies one cache entry.
new_key_type! {
    struct LruKey;
}
42
/// A cache with a fixed maximum number of entries that evicts the least
/// recently used entry when an insertion would exceed that maximum.
pub struct LruCache<K, V, S = RandomState> {
    /// Hash table of slot keys, located by the hash of the entry's key.
    table: HashTable<LruKey>,
    /// Storage for the entries themselves, addressed by slot key.
    elements: SlotMap<LruKey, LruEntry<K, V>>,
    /// Maximum number of entries retained once an insertion has completed.
    max_capacity: usize,
    /// Slot key of the most recently used entry; null while the cache is empty.
    most_recent: LruKey,
    /// Slot key of the least recently used entry; null while the cache is empty.
    least_recent: LruKey,
    /// Builds the hashers used to hash entry keys.
    build_hasher: S,
}
51
/// One cached key/value pair together with its links in the recency list.
struct LruEntry<K, V> {
    /// The key this entry was inserted under.
    key: K,
    /// The value associated with `key`.
    value: V,
    /// Links to the neighbouring entries in recency order.
    list: LruListNode,
}
57
/// Links of an entry in the recency-ordered doubly linked list.
#[derive(Copy, Clone, Default)]
struct LruListNode {
    /// Slot key of the next more recently used entry; null for the most
    /// recently used entry.
    more_recent: LruKey,
    /// Slot key of the next less recently used entry; null for the least
    /// recently used entry.
    less_recent: LruKey,
}
63
64impl<K, V> LruCache<K, V> {
65    pub fn with_capacity(capacity: usize) -> Self {
66        Self::with_capacity_and_hasher(capacity, RandomState::default())
67    }
68}
69
70impl<K, V, S> LruCache<K, V, S> {
71    pub fn with_capacity_and_hasher(max_capacity: usize, build_hasher: S) -> Self {
72        assert!(max_capacity > 0);
73        Self {
74            // Allocate one more capacity to prevent double-lookup or realloc
75            // when doing get_or_insert when full.
76            table: HashTable::with_capacity(max_capacity + 1),
77            elements: SlotMap::with_capacity_and_key(max_capacity + 1),
78            max_capacity,
79            most_recent: LruKey::null(),
80            least_recent: LruKey::null(),
81            build_hasher,
82        }
83    }
84}
85
86impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
87    fn lru_list_unlink(&mut self, lru_key: LruKey) {
88        let list = self.elements[lru_key].list;
89        if let Some(more_recent) = self.elements.get_mut(list.more_recent) {
90            more_recent.list.less_recent = list.less_recent;
91        } else {
92            self.most_recent = list.less_recent;
93        }
94        if let Some(less_recent) = self.elements.get_mut(list.less_recent) {
95            less_recent.list.more_recent = list.more_recent;
96        } else {
97            self.least_recent = list.more_recent;
98        }
99    }
100
101    fn lru_list_insert_mru(&mut self, lru_key: LruKey) {
102        let prev_most_recent_key = self.most_recent;
103        self.most_recent = lru_key;
104        if let Some(prev_most_recent) = self.elements.get_mut(prev_most_recent_key) {
105            prev_most_recent.list.more_recent = lru_key;
106        } else {
107            self.least_recent = lru_key;
108        }
109        let list = &mut self.elements[lru_key].list;
110        list.more_recent = LruKey::null();
111        list.less_recent = prev_most_recent_key;
112    }
113
114    pub fn pop_lru(&mut self) -> Option<(K, V)> {
115        if self.elements.is_empty() {
116            return None;
117        }
118        let lru_key = self.least_recent;
119        let hash = self.build_hasher.hash_one(&self.elements[lru_key].key);
120        self.lru_list_unlink(lru_key);
121        let lru_entry = self.elements.remove(lru_key).unwrap();
122        self.table
123            .find_entry(hash, |k| *k == lru_key)
124            .unwrap()
125            .remove();
126        Some((lru_entry.key, lru_entry.value))
127    }
128
129    pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
130    where
131        K: Borrow<Q>,
132        Q: Hash + Eq + ?Sized,
133    {
134        let hash = self.build_hasher.hash_one(key);
135        let lru_key = *self
136            .table
137            .find(hash, |lru_key| self.elements[*lru_key].key.borrow() == key)?;
138        self.lru_list_unlink(lru_key);
139        self.lru_list_insert_mru(lru_key);
140        let lru_node = self.elements.get(lru_key).unwrap();
141        Some(&lru_node.value)
142    }
143
144    /// Returns the old value, if any.
145    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
146        let hash = self.build_hasher.hash_one(&key);
147        match self.table.entry(
148            hash,
149            |lru_key| self.elements[*lru_key].key == key,
150            |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key),
151        ) {
152            Entry::Occupied(o) => {
153                let lru_key = *o.get();
154                self.lru_list_unlink(lru_key);
155                self.lru_list_insert_mru(lru_key);
156                Some(core::mem::replace(&mut self.elements[lru_key].value, value))
157            },
158
159            Entry::Vacant(v) => {
160                let lru_entry = LruEntry {
161                    key,
162                    value,
163                    list: LruListNode::default(),
164                };
165                let lru_key = self.elements.insert(lru_entry);
166                v.insert(lru_key);
167                self.lru_list_insert_mru(lru_key);
168                if self.elements.len() > self.max_capacity {
169                    self.pop_lru();
170                }
171                None
172            },
173        }
174    }
175
176    pub fn get_or_insert_with<Q, F: FnOnce(&Q) -> V>(&mut self, key: &Q, f: F) -> &mut V
177    where
178        K: Borrow<Q>,
179        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
180    {
181        enum Never {}
182        let Ok(ret) = self.try_get_or_insert_with::<Q, Never, _>(key, |k| Ok(f(k)));
183        ret
184    }
185
186    pub fn try_get_or_insert_with<Q, E, F: FnOnce(&Q) -> Result<V, E>>(
187        &mut self,
188        key: &Q,
189        f: F,
190    ) -> Result<&mut V, E>
191    where
192        K: Borrow<Q>,
193        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
194    {
195        let hash = self.build_hasher.hash_one(key);
196        match self.table.entry(
197            hash,
198            |lru_key| self.elements[*lru_key].key.borrow() == key,
199            |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key),
200        ) {
201            Entry::Occupied(o) => {
202                let lru_key = *o.get();
203                if lru_key != self.most_recent {
204                    self.lru_list_unlink(lru_key);
205                    self.lru_list_insert_mru(lru_key);
206                }
207                Ok(&mut self.elements[lru_key].value)
208            },
209
210            Entry::Vacant(v) => {
211                let lru_entry = LruEntry {
212                    value: f(key)?,
213                    key: key.to_owned(),
214                    list: LruListNode::default(),
215                };
216                let lru_key = self.elements.insert(lru_entry);
217                v.insert(lru_key);
218                self.lru_list_insert_mru(lru_key);
219                if self.elements.len() > self.max_capacity {
220                    self.pop_lru();
221                }
222                Ok(&mut self.elements[lru_key].value)
223            },
224        }
225    }
226}