// polars_core/chunked_array/logical/categorical/merge.rs
use std::borrow::Cow;

use super::*;
use crate::series::IsSorted;
use crate::utils::align_chunks_binary;
7fn slots_to_mut(slots: &Utf8ViewArray) -> MutablePlString {
8 slots.clone().make_mut()
9}
10
/// Lazily-built scratch state for `GlobalRevMapMerger`; only allocated once a
/// merge actually has to change the original rev-map.
struct State {
    // Maps a global category id to its index into `slots`.
    map: PlHashMap<u32, u32>,
    // Mutable copy of the string slots; new categories are appended here.
    slots: MutablePlString,
}
15
/// Merges multiple `RevMapping::Global` rev-maps (sharing one string-cache id)
/// into a single combined rev-map.
#[derive(Default)]
pub struct GlobalRevMapMerger {
    // String-cache id; all merged rev-maps must carry this same id.
    id: u32,
    // The rev-map we started from; returned unchanged if nothing was merged.
    original: Arc<RevMapping>,
    // Built on first real merge only (see `init_state`), so the no-op path
    // does no cloning.
    state: Option<State>,
}
25
impl GlobalRevMapMerger {
    /// Create a merger seeded with `rev_map`.
    ///
    /// # Panics
    /// Panics if `rev_map` is not `RevMapping::Global` — callers must only
    /// construct this for global (string-cache backed) rev-maps.
    pub fn new(rev_map: Arc<RevMapping>) -> Self {
        let RevMapping::Global(_, _, id) = rev_map.as_ref() else {
            unreachable!()
        };

        GlobalRevMapMerger {
            state: None,
            id: *id,
            original: rev_map,
        }
    }

    /// Materialize the mutable scratch state by cloning the original map and
    /// slots. Deferred until a merge actually differs from `original`.
    fn init_state(&mut self) {
        let RevMapping::Global(map, slots, _) = self.original.as_ref() else {
            unreachable!()
        };
        self.state = Some(State {
            map: (*map).clone(),
            slots: slots_to_mut(slots),
        })
    }

    /// Fold another global rev-map into this merger.
    ///
    /// # Errors
    /// Fails with a string-cache mismatch if `rev_map` is not global or was
    /// built under a different string-cache id.
    pub fn merge_map(&mut self, rev_map: &Arc<RevMapping>) -> PolarsResult<()> {
        // Fast path: physically the same rev-map, nothing to merge.
        if Arc::ptr_eq(&self.original, rev_map) {
            return Ok(());
        }

        let RevMapping::Global(map, slots, id) = rev_map.as_ref() else {
            polars_bail!(string_cache_mismatch)
        };
        polars_ensure!(*id == self.id, string_cache_mismatch);

        if self.state.is_none() {
            self.init_state()
        }
        let state = self.state.as_mut().unwrap();

        // NOTE: the closure mutates `state.slots` while `state.map` is
        // borrowed via `entry` — sound because the closure captures only the
        // disjoint `slots` field (Rust 2021 closure captures).
        for (cat, idx) in map.iter() {
            state.map.entry(*cat).or_insert_with(|| {
                // SAFETY: `idx` comes from the same `RevMapping::Global` as
                // `slots`, so it is a valid in-bounds index into `slots`.
                let str_val = unsafe { slots.value_unchecked(*idx as usize) };
                let new_idx = state.slots.len() as u32;
                state.slots.push(Some(str_val));

                new_idx
            });
        }
        Ok(())
    }

    /// Consume the merger and return the combined rev-map. If no merge ever
    /// changed anything, the original `Arc` is returned without allocation.
    pub fn finish(self) -> Arc<RevMapping> {
        match self.state {
            None => self.original,
            Some(state) => {
                let new_rev = RevMapping::Global(state.map, state.slots.into(), self.id);
                Arc::new(new_rev)
            },
        }
    }
}
89
90fn merge_local_rhs_categorical<'a>(
91 categories: &'a Utf8ViewArray,
92 ca_right: &'a CategoricalChunked,
93) -> Result<(UInt32Chunked, Arc<RevMapping>), PolarsError> {
94 polars_warn!(
98 CategoricalRemappingWarning,
99 "Local categoricals have different encodings, expensive re-encoding is done \
100 to perform this merge operation. Consider using a StringCache or an Enum type \
101 if the categories are known in advance"
102 );
103
104 let RevMapping::Local(cats_right, _) = &**ca_right.get_rev_map() else {
105 unreachable!()
106 };
107
108 let cats_left_hashmap = PlHashMap::from_iter(
109 categories
110 .values_iter()
111 .enumerate()
112 .map(|(k, v)| (v, k as u32)),
113 );
114 let mut new_categories = slots_to_mut(categories);
115 let mut idx_mapping = PlHashMap::with_capacity(cats_right.len());
116
117 for (idx, s) in cats_right.values_iter().enumerate() {
118 if let Some(v) = cats_left_hashmap.get(&s) {
119 idx_mapping.insert(idx as u32, *v);
120 } else {
121 idx_mapping.insert(idx as u32, new_categories.len() as u32);
122 new_categories.push(Some(s));
123 }
124 }
125 let new_rev_map = Arc::new(RevMapping::build_local(new_categories.into()));
126 Ok((
127 ca_right
128 .physical
129 .apply(|opt_v| opt_v.map(|v| *idx_mapping.get(&v).unwrap())),
130 new_rev_map,
131 ))
132}
133
/// Operation run on the two physical (category-index) arrays once their
/// rev-maps have been made compatible; returns the resulting physical.
pub trait CategoricalMergeOperation {
    fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult<UInt32Chunked>;
}
137
138pub fn call_categorical_merge_operation<I: CategoricalMergeOperation>(
140 cat_left: &CategoricalChunked,
141 cat_right: &CategoricalChunked,
142 merge_ops: I,
143) -> PolarsResult<CategoricalChunked> {
144 let rev_map_left = cat_left.get_rev_map();
145 let rev_map_right = cat_right.get_rev_map();
146 let (mut new_physical, new_rev_map) = match (&**rev_map_left, &**rev_map_right) {
147 (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => {
148 let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_left.clone());
149 rev_map_merger.merge_map(rev_map_right)?;
150 (
151 merge_ops.finish(cat_left.physical(), cat_right.physical())?,
152 rev_map_merger.finish(),
153 )
154 },
155 (RevMapping::Local(_, idl), RevMapping::Local(_, idr))
156 if idl == idr && cat_left.is_enum() == cat_right.is_enum() =>
157 {
158 (
159 merge_ops.finish(cat_left.physical(), cat_right.physical())?,
160 rev_map_left.clone(),
161 )
162 },
163 (RevMapping::Local(categorical, _), RevMapping::Local(_, _))
164 if !cat_left.is_enum() && !cat_right.is_enum() =>
165 {
166 let (rhs_physical, rev_map) = merge_local_rhs_categorical(categorical, cat_right)?;
167 (
168 merge_ops.finish(cat_left.physical(), &rhs_physical)?,
169 rev_map,
170 )
171 },
172 (RevMapping::Local(_, _), RevMapping::Local(_, _))
173 if cat_left.is_enum() | cat_right.is_enum() =>
174 {
175 polars_bail!(ComputeError: "can not merge incompatible Enum types")
176 },
177 _ => polars_bail!(string_cache_mismatch),
178 };
179 if cat_left.uses_lexical_ordering() {
182 new_physical.set_sorted_flag(IsSorted::Not)
183 }
184
185 unsafe {
187 Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
188 new_physical,
189 new_rev_map,
190 cat_left.is_enum(),
191 cat_left.get_ordering(),
192 ))
193 }
194}
195
/// Merge operation that keeps the (possibly re-encoded) right-hand physical
/// unchanged; used to make two categoricals compatible without combining data.
struct DoNothing;
impl CategoricalMergeOperation for DoNothing {
    fn finish(self, _lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult<UInt32Chunked> {
        Ok(rhs.clone())
    }
}
202
203pub fn make_rhs_categoricals_compatible(
205 ca_left: &CategoricalChunked,
206 ca_right: &CategoricalChunked,
207) -> PolarsResult<(CategoricalChunked, CategoricalChunked)> {
208 let new_ca_right = call_categorical_merge_operation(ca_left, ca_right, DoNothing)?;
209
210 let mut new_ca_left = ca_left.clone();
212 unsafe {
214 new_ca_left.set_rev_map(
215 new_ca_right.get_rev_map().clone(),
216 ca_left.get_rev_map().len() == new_ca_right.get_rev_map().len(),
217 )
218 };
219
220 Ok((new_ca_left, new_ca_right))
221}
222
/// Make the categorical inner types of two list columns compatible, rewriting
/// the right list's inner arrays in place with the re-encoded physical.
pub fn make_rhs_list_categoricals_compatible(
    mut list_ca_left: ListChunked,
    list_ca_right: ListChunked,
) -> PolarsResult<(ListChunked, ListChunked)> {
    // Merge the inner categoricals; errors if the inner dtypes are not
    // categorical or cannot be reconciled.
    let cat_left = list_ca_left.get_inner();
    let cat_right = list_ca_right.get_inner();
    let (cat_left, cat_right) =
        make_rhs_categoricals_compatible(cat_left.categorical()?, cat_right.categorical()?)?;

    // The left physical was not rewritten, only its rev-map changed, so
    // updating the inner dtype is enough for the left list.
    list_ca_left.set_inner_dtype(cat_left.dtype().clone());

    // Align chunking of the right list with its re-encoded physical so the
    // chunks can be zipped one-to-one below.
    let (list_ca_right, cat_physical): (Cow<ListChunked>, Cow<UInt32Chunked>) =
        align_chunks_binary(&list_ca_right, cat_right.physical());
    let mut list_ca_right = list_ca_right.into_owned();
    // SAFETY(review): swaps each chunk's values for the re-encoded physical
    // while reusing the original offsets/validity — assumes re-encoding
    // preserved element count per chunk (alignment above should ensure this).
    unsafe {
        list_ca_right
            .downcast_iter_mut()
            .zip(cat_physical.chunks())
            .for_each(|(arr, new_phys)| {
                *arr = ListArray::new(
                    arr.dtype().clone(),
                    arr.offsets().clone(),
                    new_phys.clone(),
                    arr.validity().cloned(),
                )
            });
    }
    // Values were remapped, so any previous sortedness claim is invalid.
    list_ca_right.set_sorted_flag(IsSorted::Not);
    list_ca_right.set_inner_dtype(cat_right.dtype().clone());
    Ok((list_ca_left, list_ca_right))
}