polars_core/chunked_array/logical/categorical/merge.rs

use std::borrow::Cow;

use super::*;
use crate::series::IsSorted;
use crate::utils::align_chunks_binary;

// Copy the category strings into a mutable builder so new categories can be appended.
fn slots_to_mut(slots: &Utf8ViewArray) -> MutablePlString {
    slots.clone().make_mut()
}

// Scratch state built up while merging: the merged `cat -> slot index` map
// and the merged category strings.
struct State {
    map: PlHashMap<u32, u32>,
    slots: MutablePlString,
}

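// Merges global rev-maps that share a string cache, deduplicating categories
// and assigning new slot indices to categories first seen on the right-hand side.
//
// A minimal usage sketch (hypothetical `rev_map_a` / `rev_map_b`, both
// `RevMapping::Global` values created under the same string cache):
//
//     let mut merger = GlobalRevMapMerger::new(rev_map_a);
//     merger.merge_map(&rev_map_b)?;
//     let merged = merger.finish();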
#[derive(Default)]
pub struct GlobalRevMapMerger {
    id: u32,
    original: Arc<RevMapping>,
    // Only initialize the state when we encounter a rev-map from a
    // different source, but from the same string cache.
    state: Option<State>,
}

impl GlobalRevMapMerger {
    pub fn new(rev_map: Arc<RevMapping>) -> Self {
        let RevMapping::Global(_, _, id) = rev_map.as_ref() else {
            unreachable!()
        };

        GlobalRevMapMerger {
            state: None,
            id: *id,
            original: rev_map,
        }
    }

    fn init_state(&mut self) {
        let RevMapping::Global(map, slots, _) = self.original.as_ref() else {
            unreachable!()
        };
        self.state = Some(State {
            map: (*map).clone(),
            slots: slots_to_mut(slots),
        })
    }

    pub fn merge_map(&mut self, rev_map: &Arc<RevMapping>) -> PolarsResult<()> {
        // Happy path: both rev-maps come from the same source.
        if Arc::ptr_eq(&self.original, rev_map) {
            return Ok(());
        }

        let RevMapping::Global(map, slots, id) = rev_map.as_ref() else {
            polars_bail!(string_cache_mismatch)
        };
        polars_ensure!(*id == self.id, string_cache_mismatch);

        if self.state.is_none() {
            self.init_state()
        }
        let state = self.state.as_mut().unwrap();

        // Add any category not yet in the merged map and append its string
        // to the merged slots.
        for (cat, idx) in map.iter() {
            state.map.entry(*cat).or_insert_with(|| {
                // SAFETY: `idx` comes from the rev-map itself, so it is within bounds.
                let str_val = unsafe { slots.value_unchecked(*idx as usize) };
                let new_idx = state.slots.len() as u32;
                state.slots.push(Some(str_val));

                new_idx
            });
        }
        Ok(())
    }

    pub fn finish(self) -> Arc<RevMapping> {
        match self.state {
            None => self.original,
            Some(state) => {
                let new_rev = RevMapping::Global(state.map, state.slots.into(), self.id);
                Arc::new(new_rev)
            },
        }
    }
}

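// Re-encode the right-hand side so that it uses the left-hand categories.
// A hypothetical illustration: with left categories ["a", "b"] and right
// categories ["b", "c"], the merged categories become ["a", "b", "c"] and
// the right physicals are remapped 0 -> 1 ("b") and 1 -> 2 ("c").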
fn merge_local_rhs_categorical<'a>(
    categories: &'a Utf8ViewArray,
    ca_right: &'a CategoricalChunked,
) -> PolarsResult<(UInt32Chunked, Arc<RevMapping>)> {
    // Counterpart of the GlobalRevMapMerger.
    // For local categoricals we also need to change the physicals, not only the rev-map.

    polars_warn!(
        CategoricalRemappingWarning,
        "Local categoricals have different encodings, expensive re-encoding is done \
        to perform this merge operation. Consider using a StringCache or an Enum type \
        if the categories are known in advance"
    );

    let RevMapping::Local(cats_right, _) = &**ca_right.get_rev_map() else {
        unreachable!()
    };

    // Map every left-hand category string to its physical index.
    let cats_left_hashmap = PlHashMap::from_iter(
        categories
            .values_iter()
            .enumerate()
            .map(|(k, v)| (v, k as u32)),
    );
    let mut new_categories = slots_to_mut(categories);
    let mut idx_mapping = PlHashMap::with_capacity(cats_right.len());

    // Reuse the left index when a right category already exists on the left;
    // otherwise append the category and map to its new index.
    for (idx, s) in cats_right.values_iter().enumerate() {
        if let Some(v) = cats_left_hashmap.get(&s) {
            idx_mapping.insert(idx as u32, *v);
        } else {
            idx_mapping.insert(idx as u32, new_categories.len() as u32);
            new_categories.push(Some(s));
        }
    }
    let new_rev_map = Arc::new(RevMapping::build_local(new_categories.into()));
    Ok((
        ca_right
            .physical
            .apply(|opt_v| opt_v.map(|v| *idx_mapping.get(&v).unwrap())),
        new_rev_map,
    ))
}

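// The operation applied to the physical columns once the rev-maps have been
// made compatible. A hypothetical implementation that concatenates the two
// physicals (assuming `ChunkedArray::append` is available and fallible in
// this version) could look like:
//
//     struct Append;
//     impl CategoricalMergeOperation for Append {
//         fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult<UInt32Chunked> {
//             let mut out = lhs.clone();
//             out.append(rhs)?;
//             Ok(out)
//         }
//     }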
pub trait CategoricalMergeOperation {
    fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult<UInt32Chunked>;
}

// Make the right categorical compatible with the left while applying the merge operation.
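//
// A hypothetical call site, using the `DoNothing` op defined further below to
// re-encode `cat_right` without combining any values:
//
//     let remapped = call_categorical_merge_operation(&cat_left, &cat_right, DoNothing)?;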
pub fn call_categorical_merge_operation<I: CategoricalMergeOperation>(
    cat_left: &CategoricalChunked,
    cat_right: &CategoricalChunked,
    merge_ops: I,
) -> PolarsResult<CategoricalChunked> {
    let rev_map_left = cat_left.get_rev_map();
    let rev_map_right = cat_right.get_rev_map();
    let (mut new_physical, new_rev_map) = match (&**rev_map_left, &**rev_map_right) {
        // Global rev-maps from the same string cache: merge the rev-maps;
        // the physicals can be used as-is.
        (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => {
            let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_left.clone());
            rev_map_merger.merge_map(rev_map_right)?;
            (
                merge_ops.finish(cat_left.physical(), cat_right.physical())?,
                rev_map_merger.finish(),
            )
        },
        // Identical local rev-maps: nothing to merge.
        (RevMapping::Local(_, idl), RevMapping::Local(_, idr))
            if idl == idr && cat_left.is_enum() == cat_right.is_enum() =>
        {
            (
                merge_ops.finish(cat_left.physical(), cat_right.physical())?,
                rev_map_left.clone(),
            )
        },
        // Differing local rev-maps: re-encode the right-hand side first.
        (RevMapping::Local(categorical, _), RevMapping::Local(_, _))
            if !cat_left.is_enum() && !cat_right.is_enum() =>
        {
            let (rhs_physical, rev_map) = merge_local_rhs_categorical(categorical, cat_right)?;
            (
                merge_ops.finish(cat_left.physical(), &rhs_physical)?,
                rev_map,
            )
        },
        (RevMapping::Local(_, _), RevMapping::Local(_, _))
            if cat_left.is_enum() || cat_right.is_enum() =>
        {
            polars_bail!(ComputeError: "cannot merge incompatible Enum types")
        },
        _ => polars_bail!(string_cache_mismatch),
    };
    // During the merge operation, the sorted flag might get set on the underlying
    // physical. Ensure the flag is not set if we use lexical ordering.
    if cat_left.uses_lexical_ordering() {
        new_physical.set_sorted_flag(IsSorted::Not)
    }

    // SAFETY: physical and rev-map are correctly constructed above.
    unsafe {
        Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
            new_physical,
            new_rev_map,
            cat_left.is_enum(),
            cat_left.get_ordering(),
        ))
    }
}

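// Merge operation that leaves the (re-encoded) right physical unchanged.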
struct DoNothing;
impl CategoricalMergeOperation for DoNothing {
    fn finish(self, _lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult<UInt32Chunked> {
        Ok(rhs.clone())
    }
}

// Make the right categorical compatible with the left.
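//
// A minimal usage sketch (hypothetical `ca_left` / `ca_right` values):
//
//     let (lhs, rhs) = make_rhs_categoricals_compatible(&ca_left, &ca_right)?;
//     // Both sides now share the same rev-map.
//     assert!(Arc::ptr_eq(lhs.get_rev_map(), rhs.get_rev_map()));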
pub fn make_rhs_categoricals_compatible(
    ca_left: &CategoricalChunked,
    ca_right: &CategoricalChunked,
) -> PolarsResult<(CategoricalChunked, CategoricalChunked)> {
    let new_ca_right = call_categorical_merge_operation(ca_left, ca_right, DoNothing)?;

    // Alter the rev-map of the left side as well.
    let mut new_ca_left = ca_left.clone();
    // SAFETY: we just made both rev-maps compatible; we only appended categories.
    unsafe {
        new_ca_left.set_rev_map(
            new_ca_right.get_rev_map().clone(),
            ca_left.get_rev_map().len() == new_ca_right.get_rev_map().len(),
        )
    };

    Ok((new_ca_left, new_ca_right))
}

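// List variant of `make_rhs_categoricals_compatible`: makes the inner
// categorical of the right list compatible with that of the left, rebuilding
// the right list chunks around the re-encoded physicals.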
pub fn make_rhs_list_categoricals_compatible(
    mut list_ca_left: ListChunked,
    list_ca_right: ListChunked,
) -> PolarsResult<(ListChunked, ListChunked)> {
    // Make the inner categoricals compatible.
    let cat_left = list_ca_left.get_inner();
    let cat_right = list_ca_right.get_inner();
    let (cat_left, cat_right) =
        make_rhs_categoricals_compatible(cat_left.categorical()?, cat_right.categorical()?)?;

    // We only appended categories at the end of the rev-map,
    // so we only need to change the inner dtype.
    list_ca_left.set_inner_dtype(cat_left.dtype().clone());

    // We changed the physicals and the rev-map; the offsets and validity buffers are still valid.
    let (list_ca_right, cat_physical): (Cow<ListChunked>, Cow<UInt32Chunked>) =
        align_chunks_binary(&list_ca_right, cat_right.physical());
    let mut list_ca_right = list_ca_right.into_owned();
    // SAFETY: chunks are aligned; length and dtype remain correct.
    unsafe {
        list_ca_right
            .downcast_iter_mut()
            .zip(cat_physical.chunks())
            .for_each(|(arr, new_phys)| {
                *arr = ListArray::new(
                    arr.dtype().clone(),
                    arr.offsets().clone(),
                    new_phys.clone(),
                    arr.validity().cloned(),
                )
            });
    }
    // Reset the sorted flag and set the inner dtype that carries the merged categories.
    list_ca_right.set_sorted_flag(IsSorted::Not);
    list_ca_right.set_inner_dtype(cat_right.dtype().clone());
    Ok((list_ca_left, list_ca_right))
}