polars_utils/parma/raw/
mod.rs

1use std::alloc::Layout;
2use std::marker::PhantomData;
3use std::mem::MaybeUninit;
4use std::ptr::without_provenance_mut;
5use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
6use std::sync::{Condvar, Mutex, MutexGuard, RwLock};
7
8mod key;
9mod probe;
10
11pub use key::Key;
12use probe::{Prober, TagGroup};
13
/// Header placed at the start of a table allocation. The tag groups and the
/// entry-pointer slots follow it in the same allocation (see `layout`).
#[repr(C)]
struct AllocHeader<K: ?Sized, V> {
    // Number of entry slots that follow this header.
    num_entries: usize,
    // Number of deleted entries; incremented on removal and on failed value
    // initialization, and subtracted from the size estimate when reallocating.
    num_deletions: AtomicUsize,

    // Must be decremented when starting an insertion, waiting for a new
    // allocation if the counter is zero.
    claim_start_semaphore: AtomicUsize,

    // Must be decremented when an insertion has claimed a slot in the entry table,
    // used to ensure all entries are up-to-date before starting the rehashing process.
    claim_done_barrier: AtomicUsize,

    marker: PhantomData<(Box<K>, V)>,
    // Zero-length array: stores nothing, but gives the header at least the
    // alignment of `TagGroup`.
    align: [TagGroup; 0],
}
30
31impl<K: ?Sized, V> AllocHeader<K, V> {
32    fn layout(num_entries: usize) -> Layout {
33        // Layout: AllocHeader [tags] [entries]
34        assert!(num_entries.is_power_of_two() && num_entries >= size_of::<TagGroup>());
35        let mut layout = Layout::new::<Self>();
36        layout = layout
37            .extend(Layout::array::<TagGroup>(num_entries / size_of::<TagGroup>()).unwrap())
38            .unwrap()
39            .0;
40        layout = layout
41            .extend(Layout::array::<AtomicPtr<EntryHeader<K, V>>>(num_entries).unwrap())
42            .unwrap()
43            .0;
44        layout
45    }
46
47    #[inline(always)]
48    unsafe fn tags(&self, alloc: *mut Self) -> &[TagGroup] {
49        unsafe {
50            let p = alloc.byte_add(size_of::<Self>());
51            core::slice::from_raw_parts(p.cast(), self.num_entries / size_of::<TagGroup>())
52        }
53    }
54
55    #[inline(always)]
56    unsafe fn entries(&self, alloc: *mut Self) -> &[AtomicPtr<EntryHeader<K, V>>] {
57        unsafe {
58            let p = alloc.byte_add(size_of::<Self>() + self.num_entries);
59            core::slice::from_raw_parts(p.cast(), self.num_entries)
60        }
61    }
62
63    #[inline(always)]
64    #[allow(clippy::mut_from_ref)] // Does not borrow from &self, but from alloc.
65    unsafe fn tags_mut(&self, alloc: *mut Self) -> &mut [TagGroup] {
66        unsafe {
67            let p = alloc.byte_add(size_of::<Self>());
68            core::slice::from_raw_parts_mut(p.cast(), self.num_entries / size_of::<TagGroup>())
69        }
70    }
71
72    #[inline(always)]
73    #[allow(clippy::mut_from_ref)] // Does not borrow from &self, but from alloc.
74    unsafe fn entries_mut(&self, alloc: *mut Self) -> &mut [AtomicPtr<EntryHeader<K, V>>] {
75        unsafe {
76            let p = alloc.byte_add(size_of::<Self>() + self.num_entries);
77            core::slice::from_raw_parts_mut(p.cast(), self.num_entries)
78        }
79    }
80
81    fn new(num_entries: usize) -> *mut Self {
82        let layout = Self::layout(num_entries);
83        unsafe {
84            let alloc = std::alloc::alloc(layout).cast::<Self>();
85            if alloc.is_null() {
86                std::alloc::handle_alloc_error(layout);
87            }
88
89            let max_load = probe::max_load(num_entries);
90            alloc.write(Self {
91                num_entries,
92                num_deletions: AtomicUsize::new(0),
93                claim_start_semaphore: AtomicUsize::new(max_load),
94                claim_done_barrier: AtomicUsize::new(max_load),
95                marker: PhantomData,
96                align: [],
97            });
98
99            let tags_p = alloc.byte_add(size_of::<Self>()) as *mut u8;
100            let tags: &mut [MaybeUninit<TagGroup>] =
101                core::slice::from_raw_parts_mut(tags_p.cast(), num_entries / size_of::<TagGroup>());
102            tags.fill_with(|| MaybeUninit::new(TagGroup::all_empty()));
103
104            let entries_p = alloc.byte_add(size_of::<Self>() + num_entries);
105            let entries: &mut [MaybeUninit<AtomicPtr<u8>>] =
106                core::slice::from_raw_parts_mut(entries_p.cast(), num_entries);
107            entries
108                .fill_with(|| MaybeUninit::new(AtomicPtr::new(without_provenance_mut(UNCLAIMED))));
109
110            alloc
111        }
112    }
113
114    unsafe fn free(slf: *mut Self) {
115        unsafe {
116            if slf != &raw const EMPTY_ALLOC_LOC as _ {
117                let layout = Self::layout((*slf).num_entries);
118                std::alloc::dealloc(slf.cast(), layout);
119            }
120        }
121    }
122
123    // Returns true if you may proceed with the insert attempt, false if you
124    // should wait for reallocation to occur.
125    fn start_claim_attempt(&self) -> bool {
126        self.claim_start_semaphore
127            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |attempts_left| {
128                attempts_left.checked_sub(1)
129            })
130            .is_ok()
131    }
132
    /// Gives back the permit taken by `start_claim_attempt` when an insertion
    /// attempt ends without claiming a slot.
    ///
    /// If the semaphore was at zero beforehand, all threads blocked on
    /// `waiting_for_alloc` are woken so they can retry.
    fn abort_claim_attempt(
        &self,
        alloc_lock: &Mutex<TableLockState<K, V>>,
        waiting_for_alloc: &Condvar,
    ) {
        let old = self.claim_start_semaphore.fetch_add(1, Ordering::Relaxed);
        if old == 0 {
            // We need to acquire (and release) the lock before notifying to
            // prevent the race condition [r:check -> w:notify -> r:wait].
            drop(alloc_lock.lock());
            waiting_for_alloc.notify_all();
        }
    }
146
    /// Records that an insertion has claimed its slot by decrementing
    /// `claim_done_barrier`.
    ///
    /// When this call brings the barrier to zero, all threads blocked on
    /// `waiting_for_alloc` are woken.
    fn finish_claim_attempt(
        &self,
        alloc_lock: &Mutex<TableLockState<K, V>>,
        waiting_for_alloc: &Condvar,
    ) {
        let old = self.claim_done_barrier.fetch_sub(1, Ordering::Release);
        if old == 1 {
            // We need to acquire (and release) the lock before notifying to
            // prevent the race condition [r:check -> w:notify -> r:wait].
            // The guard is dropped immediately; it is not held across the
            // notification.
            drop(alloc_lock.lock());
            waiting_for_alloc.notify_all();
        }
    }
160}
161
// A pointer to an entry in the table must be in one of three states:
//     0               = unclaimed entry
//     p               = claimed entry
//     usize::MAX      = claimed entry, now deleted
// The only valid transitions are those which move down the above list.
const UNCLAIMED: usize = 0;
const DELETED: usize = usize::MAX;

// The state field inside an entry is determined by the bottom three bits.
// If the DELETE_BIT is set then the entry is considered to be deleted and the
// upper bits contain the next pointer in the freelist. For this reason entries
// have to be aligned to at least 8 bytes. Otherwise, the top bits contain
// the hash.
//
// INIT_BIT is set once the entry's value has been written successfully.
const INIT_BIT: usize = 0b001;
// WAIT_BIT is set by a thread that is about to block until initialization
// finishes; the initializing thread notifies waiters when it sees this bit.
const WAIT_BIT: usize = 0b010;
const DELETE_BIT: usize = 0b100;
178
/// Fixed-size head of a heap-allocated entry. The key bytes are stored after
/// it in the same allocation (see `layout` and `key_ptr`).
///
/// The 8-byte alignment keeps the low three bits of any entry address free
/// for the state flag bits.
#[repr(C, align(8))]
struct EntryHeader<K: ?Sized, V> {
    // Flag bits in the low three bits; the remaining bits hold the hash, or
    // the next freelist pointer once DELETE_BIT is set.
    state: AtomicPtr<EntryHeader<K, V>>,
    // Only initialized once INIT_BIT has been set in `state`.
    value: MaybeUninit<V>,
    marker: PhantomData<K>,
}
185
186impl<K: Key + ?Sized, V> EntryHeader<K, V> {
187    fn layout(key: &K) -> Layout {
188        let key_layout = Layout::from_size_align(key.size(), K::align()).unwrap();
189        Layout::new::<EntryHeader<K, V>>()
190            .extend(key_layout)
191            .unwrap()
192            .0
193    }
194
195    #[inline(always)]
196    fn state_ptr(entry: *mut Self) -> *mut AtomicPtr<EntryHeader<K, V>> {
197        unsafe { &raw mut (*entry).state }
198    }
199
200    #[inline(always)]
201    fn val_ptr(entry: *mut Self) -> *mut V {
202        unsafe { (&raw mut (*entry).value).cast() }
203    }
204
205    #[inline(always)]
206    unsafe fn key_ptr(entry: *mut Self) -> *mut u8 {
207        unsafe {
208            entry
209                .byte_add(size_of::<EntryHeader<K, V>>().next_multiple_of(K::align()))
210                .cast()
211        }
212    }
213
214    fn new(hash: usize, key: &K) -> *mut Self {
215        let layout = Self::layout(key);
216        unsafe {
217            let p = std::alloc::alloc(layout).cast::<Self>();
218            if p.is_null() {
219                std::alloc::handle_alloc_error(layout);
220            }
221            let state = without_provenance_mut(hash & !(INIT_BIT | WAIT_BIT | DELETE_BIT));
222            Self::state_ptr(p).write(AtomicPtr::new(state));
223            key.init(Self::key_ptr(p));
224            p
225        }
226    }
227
228    unsafe fn free(entry: *mut Self) {
229        unsafe {
230            let key = K::get(Self::key_ptr(entry));
231            let layout = Self::layout(key);
232            std::alloc::dealloc(entry.cast(), layout);
233        }
234    }
235
    /// Waits for this entry to be initialized. Returns true if successful, false
    /// if the value was deleted.
    ///
    /// Blocks on `waiting_for_init` (paired with `init_lock`) only when the
    /// entry is neither initialized nor deleted on the first check.
    ///
    /// # Safety
    /// NOTE(review): presumably `entry` must point to a live entry that is not
    /// freed while waiting — confirm against callers.
    unsafe fn wait_for_init(
        entry: *mut Self,
        init_lock: &Mutex<()>,
        waiting_for_init: &Condvar,
    ) -> bool {
        unsafe {
            let state_loc = &*Self::state_ptr(entry);
            let mut state = state_loc.load(Ordering::Acquire);
            // Fast path: already initialized or deleted, no locking needed.
            if state.addr() & (DELETE_BIT | INIT_BIT) != 0 {
                return state.addr() & DELETE_BIT == 0;
            }

            // First acquire the lock then try setting the wait bit.
            let mut guard = init_lock.lock().unwrap();
            if let Err(new_state) = state_loc.compare_exchange(
                state,
                state.map_addr(|p| p | WAIT_BIT),
                Ordering::Relaxed,
                Ordering::Acquire,
            ) {
                // The state changed under us; re-evaluate with the fresh value.
                state = new_state;
            }

            // Wait until init is complete.
            loop {
                if state.addr() & (DELETE_BIT | INIT_BIT) != 0 {
                    return state.addr() & DELETE_BIT == 0;
                }

                guard = waiting_for_init.wait(guard).unwrap();
                state = state_loc.load(Ordering::Acquire);
            }
        }
    }
272}
273
/// A concurrent hash table.
///
/// Removed values are not dropped immediately; they are kept on an internal
/// freelist until [`RawTable::gc`] is called or the table is dropped.
#[repr(align(128))] // To avoid false sharing.
pub struct RawTable<K: Key + ?Sized, V> {
    // The allocation currently used for lookups and inserts.
    cur_alloc: AtomicPtr<AllocHeader<K, V>>,
    // Head of the list of deleted entries, linked through their `state` field.
    freelist_head: AtomicPtr<EntryHeader<K, V>>,
    // Guards the list of retired allocations; paired with `waiting_for_alloc`.
    alloc_lock: Mutex<TableLockState<K, V>>,
    waiting_for_alloc: Condvar,
    // Paired with `waiting_for_init` for threads waiting on an entry's value.
    init_lock: Mutex<()>,
    waiting_for_init: Condvar,
    // Taken exclusively while rehashing and shared by `remove`.
    rehash_lock: RwLock<()>,
    marker: PhantomData<(Box<K>, V)>,
}
286
// `TableLockState` stores raw pointers, so `Send`/`Sync` are not derived
// automatically and are asserted manually here.
// NOTE(review): the soundness argument for these bounds is not spelled out in
// this file — confirm before relaxing them.
unsafe impl<K: Key + Send + ?Sized, V: Send> Send for RawTable<K, V> {}
unsafe impl<K: Key + Send + Sync + ?Sized, V: Send + Sync> Sync for RawTable<K, V> {}
289
/// State protected by `RawTable::alloc_lock`.
struct TableLockState<K: ?Sized, V> {
    // Allocations replaced by a reallocation; they are freed by `RawTable::gc`.
    old_allocs: Vec<*mut AllocHeader<K, V>>,
}
293
294impl<K: Key + ?Sized, V> RawTable<K, V> {
295    /// Creates a new [`RawTable`].
296    pub const fn new() -> Self {
297        Self {
298            cur_alloc: AtomicPtr::new(&raw const EMPTY_ALLOC_LOC as _),
299            freelist_head: AtomicPtr::new(core::ptr::null_mut()),
300            alloc_lock: Mutex::new(TableLockState {
301                old_allocs: Vec::new(),
302            }),
303            waiting_for_alloc: Condvar::new(),
304            init_lock: Mutex::new(()),
305            waiting_for_init: Condvar::new(),
306            rehash_lock: RwLock::new(()),
307            marker: PhantomData,
308        }
309    }
310
311    /// Creates a new [`RawTable`] that will not reallocate before `capacity` insertions are done.
312    pub fn with_capacity(capacity: usize) -> Self {
313        if capacity == 0 {
314            return Self::new();
315        }
316        Self {
317            cur_alloc: AtomicPtr::new(AllocHeader::new(probe::min_entries_for_load(capacity))),
318            freelist_head: AtomicPtr::new(core::ptr::null_mut()),
319            alloc_lock: Mutex::new(TableLockState {
320                old_allocs: Vec::new(),
321            }),
322            waiting_for_alloc: Condvar::new(),
323            init_lock: Mutex::new(()),
324            waiting_for_init: Condvar::new(),
325            rehash_lock: RwLock::new(()),
326            marker: PhantomData,
327        }
328    }
329
    /// Obtains an insertion permit and returns the allocation it belongs to.
    ///
    /// Tries the current allocation first without locking. If no permit is
    /// available, takes `alloc_lock` and loops: retry on the current
    /// allocation, reallocate once all outstanding claims have finished
    /// (`claim_done_barrier == 0`), or otherwise wait on `waiting_for_alloc`.
    fn start_insert_attempt(&self) -> *mut AllocHeader<K, V> {
        unsafe {
            // Fast path: lock-free permit from the current allocation.
            let alloc = self.cur_alloc.load(Ordering::Acquire);
            if (*alloc).start_claim_attempt() {
                return alloc;
            }

            let mut guard = self.alloc_lock.lock().unwrap();
            loop {
                // Re-read: another thread may have published a new allocation
                // or returned a permit in the meantime.
                let alloc = self.cur_alloc.load(Ordering::Acquire);
                let header = &*alloc;
                if header.start_claim_attempt() {
                    return alloc;
                }

                let barrier = header.claim_done_barrier.load(Ordering::Acquire);
                if barrier == 0 {
                    // Every permitted insertion has claimed its slot, so the
                    // entries can be rehashed into a new allocation.
                    guard = self.realloc(guard);
                } else {
                    guard = self.waiting_for_alloc.wait(guard).unwrap();
                }
            }
        }
    }
354
    /// Replaces the current allocation with a new one and rehashes all live
    /// entries into it.
    ///
    /// Takes and returns the `alloc_lock` guard. The old allocation is pushed
    /// onto `old_allocs` rather than freed, and waiters on
    /// `waiting_for_alloc` are notified after the new allocation is
    /// published.
    ///
    /// # Safety
    /// NOTE(review): the only visible caller invokes this once
    /// `claim_done_barrier` of the current allocation is zero; presumably
    /// that is a precondition — confirm.
    unsafe fn realloc<'a>(
        &'a self,
        mut alloc_guard: MutexGuard<'a, TableLockState<K, V>>,
    ) -> MutexGuard<'a, TableLockState<K, V>> {
        unsafe {
            let old_alloc = self.cur_alloc.load(Ordering::Relaxed);
            let old_header = &*old_alloc;
            // Upper bound on live entries: the maximum load minus recorded deletions.
            let upper_bound_len = probe::max_load(old_header.num_entries)
                - old_header.num_deletions.load(Ordering::Acquire);

            // Size for twice the live upper bound, with a minimum of 32 slots.
            let num_entries = (upper_bound_len * 2).next_power_of_two().max(32);
            let alloc = AllocHeader::<K, V>::new(num_entries);
            let header = &*alloc;

            // Rehash old entries. We must hold the rehash lock exclusively to prevent those operations
            // which may not occur during rehashing.
            let rehash_guard = self.rehash_lock.write();
            let mut entries_reinserted = 0;
            for entry in old_header.entries(old_alloc) {
                let entry_ptr = entry.load(Ordering::Relaxed);
                if entry_ptr.addr() != DELETED && entry_ptr.addr() != UNCLAIMED {
                    let state = (*EntryHeader::state_ptr(entry_ptr)).load(Ordering::Relaxed);
                    // Skip entries already marked deleted in their own state.
                    if state.addr() & DELETE_BIT == 0 {
                        // The state's non-flag bits hold the hash.
                        Self::insert_uniq_entry_exclusive(alloc, state.addr(), entry_ptr);
                        entries_reinserted += 1;
                    }
                }
            }

            // Publish the new allocation.
            // Both counters start at the maximum load; account for the slots
            // the rehashed entries already occupy.
            header
                .claim_start_semaphore
                .fetch_sub(entries_reinserted, Ordering::Relaxed);
            header
                .claim_done_barrier
                .fetch_sub(entries_reinserted, Ordering::Relaxed);
            alloc_guard.old_allocs.push(old_alloc);
            self.cur_alloc.store(alloc, Ordering::Release);
            drop(rehash_guard);
            self.waiting_for_alloc.notify_all();
            alloc_guard
        }
    }
398
    /// Runs `val_f` to initialize the value of a freshly claimed entry and
    /// publishes the outcome in the entry's state.
    ///
    /// On success the state becomes `hash | INIT_BIT` and the entry pointer is
    /// returned. If `val_f` returns an error or panics, the entry is pushed
    /// onto the freelist with `DELETE_BIT` set, its slot is marked `DELETED`,
    /// `num_deletions` is incremented, and the error is returned or the panic
    /// resumed. In both cases threads that set `WAIT_BIT` are notified.
    ///
    /// # Safety
    /// NOTE(review): presumably `new_entry_ptr` must be the entry just stored
    /// in `entry`, and `entry` must belong to `header` — confirm.
    unsafe fn try_init_entry_val<E>(
        &self,
        hash: usize,
        header: &AllocHeader<K, V>,
        entry: &AtomicPtr<EntryHeader<K, V>>,
        new_entry_ptr: *mut EntryHeader<K, V>,
        val_f: impl FnOnce(&K) -> Result<V, E>,
    ) -> Result<*mut EntryHeader<K, V>, E> {
        unsafe {
            // Catch panics from `val_f` so the entry state can be updated and
            // waiters released before the panic continues.
            let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                let key_ptr = EntryHeader::key_ptr(new_entry_ptr);
                let key = K::get(key_ptr);
                EntryHeader::val_ptr(new_entry_ptr).write(val_f(key)?);
                Ok(())
            }));
            let state = &*EntryHeader::state_ptr(new_entry_ptr);
            let old_state;
            if !matches!(r, Ok(Ok(_))) {
                // Failure: link the entry into the freelist and mark it deleted.
                let old_head = self.freelist_head.swap(new_entry_ptr, Ordering::AcqRel);
                old_state = state.swap(old_head.map_addr(|a| a | DELETE_BIT), Ordering::Release);
                entry.store(without_provenance_mut(DELETED), Ordering::Relaxed);
                header.num_deletions.fetch_add(1, Ordering::Release);
            } else {
                // Success: publish the hash together with INIT_BIT.
                old_state = state.swap(
                    without_provenance_mut((hash & !(WAIT_BIT | DELETE_BIT)) | INIT_BIT),
                    Ordering::Release,
                )
            };

            if old_state.addr() & WAIT_BIT != 0 {
                // We need to acquire (and release) the lock before notifying
                // to prevent the race condition [r:check -> w:notify -> r:wait].
                drop(self.init_lock.lock());
                self.waiting_for_init.notify_all();
            }

            match r {
                Ok(Ok(())) => Ok(new_entry_ptr),
                Ok(Err(e)) => Err(e),
                Err(panic) => std::panic::resume_unwind(panic),
            }
        }
    }
442
    /// Probes `alloc` for an initialized, non-deleted entry whose tag matches
    /// `hash` and whose key satisfies `eq`.
    ///
    /// Returns `Ok((entry, slot_index))` on a hit. Returns `Err(prober)` with
    /// the probe position at which the search stopped when the key is absent
    /// or its insertion is still in progress.
    fn find_impl(
        alloc: *mut AllocHeader<K, V>,
        hash: usize,
        mut eq: impl FnMut(&K) -> bool,
    ) -> Result<(*mut EntryHeader<K, V>, usize), Prober> {
        unsafe {
            let mut prober = Prober::new(hash);

            let header = &*alloc;
            let tags = header.tags(alloc);
            let entries = header.entries(alloc);
            let group_mask = TagGroup::idx_mask(header.num_entries);
            let mut needle = TagGroup::all_occupied(hash);
            loop {
                let group_idx = prober.get() & group_mask;
                let mut tag_group = TagGroup::load(tags.get_unchecked(group_idx));
                let mut matches = tag_group.matches(&mut needle);
                while matches.has_matches() {
                    let idx_in_group = matches.get();
                    // Each tag group covers `size_of::<TagGroup>()` consecutive slots.
                    let entry_idx = size_of::<TagGroup>() * group_idx + idx_in_group;
                    let entry_ptr = entries.get_unchecked(entry_idx).load(Ordering::Acquire);

                    // Matching tag but unclaimed, racy insert in process but definitely missing.
                    if entry_ptr.addr() == UNCLAIMED {
                        return Err(prober);
                    }

                    if entry_ptr.addr() != DELETED {
                        let state = (*EntryHeader::state_ptr(entry_ptr)).load(Ordering::Acquire);
                        if state.addr() & DELETE_BIT == 0
                            && eq(K::get(EntryHeader::<K, V>::key_ptr(entry_ptr)))
                        {
                            // Not deleted and a key hit, either a racy insert or a hit.
                            return if state.addr() & INIT_BIT != 0 {
                                Ok((entry_ptr, entry_idx))
                            } else {
                                Err(prober)
                            };
                        }
                    }
                    matches.advance();
                }

                // An empty tag in this group ends the probe sequence.
                if tag_group.empties().has_matches() {
                    return Err(prober);
                }

                prober.advance();
            }
        }
    }
494
    /// Slow path of `try_get_or_insert_with`: inserts a new entry for `key`,
    /// or returns the existing entry if another thread inserted an equal key
    /// first.
    ///
    /// `prober` is the position at which the preceding lookup in
    /// `orig_probe_alloc` stopped. It is reset if the allocation changed in
    /// the meantime. `val_f` only runs after this thread has claimed a slot.
    /// An error from `val_f` is returned as `Err`.
    fn try_find_or_insert_impl<E>(
        &self,
        orig_probe_alloc: *mut AllocHeader<K, V>,
        mut prober: Prober,
        hash: usize,
        key: &K,
        val_f: impl FnOnce(&K) -> Result<V, E>,
        mut eq: impl FnMut(&K) -> bool,
    ) -> Result<*mut EntryHeader<K, V>, E> {
        unsafe {
            // Allocate the candidate entry up front; it is freed again if an
            // equal key turns out to exist already.
            let new_entry_ptr = EntryHeader::<K, V>::new(hash, key);

            let alloc = self.start_insert_attempt();
            if alloc != orig_probe_alloc {
                // The earlier probe position belongs to a different
                // allocation; start probing from scratch.
                prober = Prober::new(hash);
            }

            let header = &*alloc;
            let tags = header.tags(alloc);
            let entries = header.entries(alloc);
            let group_mask = TagGroup::idx_mask(header.num_entries);
            let mut needle = TagGroup::all_occupied(hash);

            'probe_loop: loop {
                let group_idx = prober.get() & group_mask;
                let mut tag_group = TagGroup::load(tags.get_unchecked(group_idx));
                let matches = tag_group.matches(&mut needle);
                let empties = tag_group.empties();
                // Candidate slots: tags matching our hash plus empty tags.
                let mut insert_locs = matches | empties;
                while insert_locs.has_matches() {
                    let idx_in_group = insert_locs.get();

                    // Insert a new tag if this insert location is currently empty.
                    if empties.has_match_at(idx_in_group) {
                        if !tags.get_unchecked(group_idx).try_occupy(
                            &mut tag_group,
                            idx_in_group,
                            hash,
                        ) {
                            // Occupying the tag failed; re-examine the same group.
                            continue 'probe_loop;
                        }
                    }

                    let entry_idx = size_of::<TagGroup>() * group_idx + idx_in_group;
                    let entry = entries.get_unchecked(entry_idx);
                    let mut entry_ptr = entry.load(Ordering::Acquire);
                    if entry_ptr.addr() == UNCLAIMED {
                        // Try to claim this entry.
                        match entry.compare_exchange(
                            entry_ptr,
                            new_entry_ptr,
                            Ordering::Release,
                            Ordering::Acquire,
                        ) {
                            Ok(_) => {
                                header.finish_claim_attempt(
                                    &self.alloc_lock,
                                    &self.waiting_for_alloc,
                                );
                                return self.try_init_entry_val(
                                    hash,
                                    header,
                                    entry,
                                    new_entry_ptr,
                                    val_f,
                                );
                            },
                            Err(ev) => entry_ptr = ev,
                        }
                    }

                    // We couldn't claim the entry, see if our key is the same as
                    // whoever claimed this entry, assuming it's not deleted.
                    if entry_ptr.addr() != DELETED {
                        let entry_key = K::get(EntryHeader::key_ptr(entry_ptr));
                        if eq(entry_key) {
                            if EntryHeader::wait_for_init(
                                entry_ptr,
                                &self.init_lock,
                                &self.waiting_for_init,
                            ) {
                                // NOTE(review): unlike `gc` and `Drop`, this
                                // path frees the unused entry without calling
                                // `K::drop_in_place` on its key — confirm that
                                // is intended for all `Key` implementations.
                                EntryHeader::free(new_entry_ptr);
                                header
                                    .abort_claim_attempt(&self.alloc_lock, &self.waiting_for_alloc);
                                return Ok(entry_ptr);
                            }
                        }
                    }

                    insert_locs.advance();
                }

                prober.advance();
            }
        }
    }
591
    /// Inserts `uniq_entry_ptr` into the first empty slot on the probe
    /// sequence of `hash`, using plain (non-atomic) writes.
    ///
    /// No key comparison is performed.
    ///
    /// # Safety
    /// NOTE(review): the name and the mutable access suggest the caller needs
    /// exclusive access to `alloc` and a key not yet present in it. The loop
    /// also has no exit if no empty slot exists — confirm these preconditions.
    unsafe fn insert_uniq_entry_exclusive(
        alloc: *mut AllocHeader<K, V>,
        hash: usize,
        uniq_entry_ptr: *mut EntryHeader<K, V>,
    ) {
        unsafe {
            let header = &mut *alloc;
            let tags = header.tags_mut(alloc);
            let entries = header.entries_mut(alloc);
            let group_mask = TagGroup::idx_mask(header.num_entries);

            let mut prober = Prober::new(hash);
            loop {
                let group_idx = prober.get() & group_mask;
                let tag_group = tags.get_unchecked_mut(group_idx);
                let empties = tag_group.empties();
                if empties.has_matches() {
                    // Take the first empty slot in this group.
                    let idx_in_group = empties.get();
                    tag_group.occupy_mut(idx_in_group, hash);
                    let entry_idx = size_of::<TagGroup>() * group_idx + idx_in_group;
                    *entries.get_unchecked_mut(entry_idx).get_mut() = uniq_entry_ptr;
                    return;
                }

                prober.advance();
            }
        }
    }
620
    /// Free any resources which are no longer necessary.
    ///
    /// Detaches the freelist of deleted entries and the list of retired
    /// allocations, calls `drop_guard`, and only then drops and frees them.
    ///
    /// # Safety
    /// Until drop_guard gets called, there may not be any alive references
    /// returned by the [`RawTable`], or concurrent other operations.
    pub unsafe fn gc<F: FnOnce()>(&self, drop_guard: F) {
        let mut freelist_head = self
            .freelist_head
            .swap(core::ptr::null_mut(), Ordering::Acquire);
        let old_allocs = core::mem::take(&mut self.alloc_lock.lock().unwrap().old_allocs);
        drop_guard();

        unsafe {
            while !freelist_head.is_null() {
                let state = *(*EntryHeader::state_ptr(freelist_head)).get_mut();
                // Only entries whose value was initialized have a value to drop.
                if state.addr() & INIT_BIT != 0 {
                    core::ptr::drop_in_place(EntryHeader::val_ptr(freelist_head));
                }
                K::drop_in_place(EntryHeader::key_ptr(freelist_head));
                EntryHeader::free(freelist_head);
                // The non-flag bits of a deleted entry's state hold the next
                // freelist pointer.
                freelist_head = state.map_addr(|a| a & !(INIT_BIT | WAIT_BIT | DELETE_BIT));
            }

            for alloc in old_allocs {
                AllocHeader::free(alloc);
            }
        }
    }
649
650    /// Finds the value corresponding to a key with the given hash and equality function.
651    pub fn get(&self, hash: u64, eq: impl FnMut(&K) -> bool) -> Option<&V> {
652        unsafe {
653            let cur_alloc = self.cur_alloc.load(Ordering::Acquire);
654            let entry_ptr = Self::find_impl(cur_alloc, hash as usize, eq).ok()?.0;
655            Some(&*EntryHeader::val_ptr(entry_ptr))
656        }
657    }
658
659    /// Finds the value corresponding to a key with the given hash and equality function, or insert
660    /// a new one if the key does not exist.
661    ///
662    /// `val_f` is guaranteed to only be called when inserting a new key not currently found in the
663    /// table, even if multiple concurrent inserts occur. The key reference passed to `val_f` lives
664    /// as long as the new entry will.
665    pub fn get_or_insert_with(
666        &self,
667        hash: u64,
668        key: &K,
669        eq: impl FnMut(&K) -> bool,
670        val_f: impl FnOnce(&K) -> V,
671    ) -> &V {
672        unsafe {
673            self.try_get_or_insert_with::<()>(hash, key, eq, |k| Ok(val_f(k)))
674                .unwrap_unchecked()
675        }
676    }
677
678    /// Finds the value corresponding to a key with the given hash and equality function, or insert
679    /// a new one if the key does not exist.
680    ///
681    /// `val_f` is guaranteed to only be called when inserting a new key not currently found in the
682    /// table, even if multiple concurrent inserts occur. The key reference passed to `val_f` lives
683    /// as long as the new entry will.
684    pub fn try_get_or_insert_with<E>(
685        &self,
686        hash: u64,
687        key: &K,
688        mut eq: impl FnMut(&K) -> bool,
689        val_f: impl FnOnce(&K) -> Result<V, E>,
690    ) -> Result<&V, E> {
691        unsafe {
692            let cur_alloc = self.cur_alloc.load(Ordering::Acquire);
693            match Self::find_impl(cur_alloc, hash as usize, &mut eq) {
694                Ok((entry_ptr, _)) => Ok(&*EntryHeader::val_ptr(entry_ptr)),
695                Err(prober) => {
696                    let entry_ptr = self.try_find_or_insert_impl(
697                        cur_alloc,
698                        prober,
699                        hash as usize,
700                        key,
701                        val_f,
702                        eq,
703                    )?;
704                    Ok(&*EntryHeader::val_ptr(entry_ptr))
705                },
706            }
707        }
708    }
709
710    /// Finds and removes the value corresponding to a key with the given hash and equality function.
711    ///
712    /// Note that the value is not dropped until [`RawTable::gc`] is called or the [`RawTable`] is dropped.
713    pub fn remove(&self, hash: u64, eq: impl FnMut(&K) -> bool) -> Option<&V> {
714        unsafe {
715            // TODO: perhaps deletions could be done during a rehash by synchronizing on state?
716            let _rehash_guard = self.rehash_lock.read();
717            let alloc = self.cur_alloc.load(Ordering::Acquire);
718            let header = &*alloc;
719            let (entry_ptr, entry_idx) = Self::find_impl(alloc, hash as usize, eq).ok()?;
720
721            let state = &(*entry_ptr).state;
722            let old_state = state
723                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
724                    if s.addr() & DELETE_BIT != 0 {
725                        return None;
726                    }
727
728                    Some(s.map_addr(|a| a | DELETE_BIT))
729                })
730                .ok()?;
731
732            let group_idx = entry_idx / size_of::<TagGroup>();
733            let idx_in_group = entry_idx % size_of::<TagGroup>();
734            let old_head = self.freelist_head.swap(entry_ptr, Ordering::AcqRel);
735            state.store(
736                old_head.map_addr(|a| a | (old_state.addr() & INIT_BIT)),
737                Ordering::Release,
738            );
739            header
740                .entries(alloc)
741                .get_unchecked(entry_idx)
742                .store(without_provenance_mut(DELETED), Ordering::Relaxed);
743            header
744                .tags(alloc)
745                .get_unchecked(group_idx)
746                .delete(idx_in_group);
747            header.num_deletions.fetch_add(1, Ordering::Release);
748            Some(&*EntryHeader::val_ptr(entry_ptr))
749        }
750    }
751}
752
753impl<K: Key + ?Sized, V> Default for RawTable<K, V> {
754    fn default() -> Self {
755        Self::new()
756    }
757}
758
impl<K: Key + ?Sized, V> Drop for RawTable<K, V> {
    /// Releases everything the table owns: first the deleted entries and
    /// retired allocations (via `gc`), then every live entry of the current
    /// allocation, and finally the current allocation itself.
    fn drop(&mut self) {
        unsafe {
            // `&mut self` gives exclusive access, so no guard work is needed.
            self.gc(|| {});

            let alloc = *self.cur_alloc.get_mut();
            let header = &*alloc;
            for entry in header.entries_mut(alloc) {
                let entry_ptr = *(*entry).get_mut();
                // Skip unclaimed slots (null) and deleted slots; deleted
                // entries were already handled through the freelist.
                if entry_ptr.is_null() || entry_ptr.addr() == DELETED {
                    continue;
                }

                let state = (*EntryHeader::state_ptr(entry_ptr)).get_mut();
                // Only drop the value if it was actually initialized.
                if state.addr() & INIT_BIT != 0 {
                    core::ptr::drop_in_place(EntryHeader::val_ptr(entry_ptr));
                }
                K::drop_in_place(EntryHeader::key_ptr(entry_ptr));
                EntryHeader::free(entry_ptr);
            }
            AllocHeader::free(alloc);
        }
    }
}
783
/// Statically allocated stand-in for a table allocation, mirroring the
/// `AllocHeader [tags] [entries]` layout with a single tag group.
#[repr(C)]
struct DummyAlloc {
    header: AllocHeader<(), ()>,
    tags: [TagGroup; 1],
    // One entry slot per tag in the single tag group.
    entries: [AtomicPtr<EntryHeader<(), ()>>; size_of::<TagGroup>()],
}
790
791static EMPTY_ALLOC_LOC: DummyAlloc = DummyAlloc {
792    header: AllocHeader {
793        num_entries: size_of::<TagGroup>(),
794        num_deletions: AtomicUsize::new(0),
795        claim_start_semaphore: AtomicUsize::new(0),
796        claim_done_barrier: AtomicUsize::new(0),
797        marker: PhantomData,
798        align: [],
799    },
800    tags: [TagGroup::all_empty()],
801    entries: [
802        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
803        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
804        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
805        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
806        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
807        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
808        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
809        AtomicPtr::new(without_provenance_mut(UNCLAIMED)),
810    ],
811};