polars_core/chunked_array/ops/zip.rs

1use std::borrow::Cow;
2
3use arrow::bitmap::{Bitmap, BitmapBuilder};
4use arrow::compute::utils::{combine_validities_and, combine_validities_and_not};
5use polars_compute::if_then_else::{IfThenElseKernel, if_then_else_validity};
6
7#[cfg(feature = "object")]
8use crate::chunked_array::object::ObjectArray;
9use crate::prelude::*;
10use crate::utils::{align_chunks_binary, align_chunks_ternary};
11
/// Error message reported whenever the lengths of `self`, `mask` and `other` can neither be
/// matched nor broadcast (from length 1) against each other.
const SHAPE_MISMATCH_STR: &str = concat!(
    "shapes of `self`, `mask` and `other` ",
    "are not suitable for `zip_with` operation"
);
14
15fn if_then_else_broadcast_mask<T: PolarsDataType>(
16    mask: bool,
17    if_true: &ChunkedArray<T>,
18    if_false: &ChunkedArray<T>,
19) -> PolarsResult<ChunkedArray<T>>
20where
21    ChunkedArray<T>: ChunkExpandAtIndex<T>,
22{
23    let src = if mask { if_true } else { if_false };
24    let other = if mask { if_false } else { if_true };
25    let ret = match (src.len(), other.len()) {
26        (a, b) if a == b => src.clone(),
27        (_, 1) => src.clone(),
28        (1, other_len) => src.new_from_index(0, other_len),
29        _ => polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR),
30    };
31    Ok(ret.with_name(if_true.name().clone()))
32}
33
34fn bool_null_to_false(mask: &BooleanArray) -> Bitmap {
35    if mask.null_count() == 0 {
36        mask.values().clone()
37    } else {
38        mask.values() & mask.validity().unwrap()
39    }
40}
41
42/// Combines the validities of ca with the bits in mask using the given combiner.
43///
44/// If the mask itself has validity, those null bits are converted to false.
45fn combine_validities_chunked<
46    T: PolarsDataType,
47    F: Fn(Option<&Bitmap>, Option<&Bitmap>) -> Option<Bitmap>,
48>(
49    ca: &ChunkedArray<T>,
50    mask: &BooleanChunked,
51    combiner: F,
52) -> ChunkedArray<T> {
53    let (ca_al, mask_al) = align_chunks_binary(ca, mask);
54    let chunks = ca_al
55        .downcast_iter()
56        .zip(mask_al.downcast_iter())
57        .map(|(a, m)| {
58            let bm = bool_null_to_false(m);
59            let validity = combiner(a.validity(), Some(&bm));
60            a.clone().with_validity_typed(validity)
61        });
62    ChunkedArray::from_chunk_iter_like(ca, chunks)
63}
64
65impl<T> ChunkZip<T> for ChunkedArray<T>
66where
67    T: PolarsDataType<IsStruct = FalseT>,
68    T::Array: for<'a> IfThenElseKernel<Scalar<'a> = T::Physical<'a>>,
69    ChunkedArray<T>: ChunkExpandAtIndex<T>,
70{
71    fn zip_with(
72        &self,
73        mask: &BooleanChunked,
74        other: &ChunkedArray<T>,
75    ) -> PolarsResult<ChunkedArray<T>> {
76        let if_true = self;
77        let if_false = other;
78
79        // Broadcast mask.
80        if mask.len() == 1 {
81            return if_then_else_broadcast_mask(mask.get(0).unwrap_or(false), if_true, if_false);
82        }
83
84        // Broadcast both.
85        let ret = if if_true.len() == 1 && if_false.len() == 1 {
86            match (if_true.get(0), if_false.get(0)) {
87                (None, None) => ChunkedArray::full_null_like(if_true, mask.len()),
88                (None, Some(_)) => combine_validities_chunked(
89                    &if_false.new_from_index(0, mask.len()),
90                    mask,
91                    combine_validities_and_not,
92                ),
93                (Some(_), None) => combine_validities_chunked(
94                    &if_true.new_from_index(0, mask.len()),
95                    mask,
96                    combine_validities_and,
97                ),
98                (Some(t), Some(f)) => {
99                    let dtype = if_true.downcast_iter().next().unwrap().dtype();
100                    let chunks = mask.downcast_iter().map(|m| {
101                        let bm = bool_null_to_false(m);
102                        let t = t.clone();
103                        let f = f.clone();
104                        IfThenElseKernel::if_then_else_broadcast_both(dtype.clone(), &bm, t, f)
105                    });
106                    ChunkedArray::from_chunk_iter_like(if_true, chunks)
107                },
108            }
109
110        // Broadcast neither.
111        } else if if_true.len() == if_false.len() {
112            polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR);
113            let (mask_al, if_true_al, if_false_al) = align_chunks_ternary(mask, if_true, if_false);
114            let chunks = mask_al
115                .downcast_iter()
116                .zip(if_true_al.downcast_iter())
117                .zip(if_false_al.downcast_iter())
118                .map(|((m, t), f)| IfThenElseKernel::if_then_else(&bool_null_to_false(m), t, f));
119            ChunkedArray::from_chunk_iter_like(if_true, chunks)
120
121        // Broadcast true value.
122        } else if if_true.len() == 1 {
123            polars_ensure!(mask.len() == if_false.len(), ShapeMismatch: SHAPE_MISMATCH_STR);
124            if let Some(true_scalar) = if_true.get(0) {
125                let (mask_al, if_false_al) = align_chunks_binary(mask, if_false);
126                let chunks = mask_al
127                    .downcast_iter()
128                    .zip(if_false_al.downcast_iter())
129                    .map(|(m, f)| {
130                        let bm = bool_null_to_false(m);
131                        let t = true_scalar.clone();
132                        IfThenElseKernel::if_then_else_broadcast_true(&bm, t, f)
133                    });
134                ChunkedArray::from_chunk_iter_like(if_true, chunks)
135            } else {
136                combine_validities_chunked(if_false, mask, combine_validities_and_not)
137            }
138
139        // Broadcast false value.
140        } else if if_false.len() == 1 {
141            polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR);
142            if let Some(false_scalar) = if_false.get(0) {
143                let (mask_al, if_true_al) = align_chunks_binary(mask, if_true);
144                let chunks =
145                    mask_al
146                        .downcast_iter()
147                        .zip(if_true_al.downcast_iter())
148                        .map(|(m, t)| {
149                            let bm = bool_null_to_false(m);
150                            let f = false_scalar.clone();
151                            IfThenElseKernel::if_then_else_broadcast_false(&bm, t, f)
152                        });
153                ChunkedArray::from_chunk_iter_like(if_false, chunks)
154            } else {
155                combine_validities_chunked(if_true, mask, combine_validities_and)
156            }
157        } else {
158            polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR)
159        };
160
161        Ok(ret.with_name(if_true.name().clone()))
162    }
163}
164
// Basic, element-by-element `IfThenElseKernel` implementation for `ObjectArray`.
#[cfg(feature = "object")]
impl<T: PolarsObject> IfThenElseKernel for ObjectArray<T> {
    type Scalar<'a> = &'a T;

    /// Takes `if_true[i]` where bit `i` of `mask` is set, `if_false[i]` otherwise.
    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
        let candidates = if_true.iter().zip(if_false.iter());
        mask.iter()
            .zip(candidates)
            .map(|(pick_true, (on_true, on_false))| if pick_true { on_true } else { on_false })
            .collect_arr()
    }

    /// Like `if_then_else`, with the same scalar `if_true` used for every set bit.
    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self {
        let scalar = Some(if_true);
        mask.iter()
            .zip(if_false.iter())
            .map(|(pick_true, on_false)| if pick_true { scalar } else { on_false })
            .collect_arr()
    }

    /// Like `if_then_else`, with the same scalar `if_false` used for every unset bit.
    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let scalar = Some(if_false);
        mask.iter()
            .zip(if_true.iter())
            .map(|(pick_true, on_true)| if pick_true { on_true } else { scalar })
            .collect_arr()
    }

    /// Builds an array the length of `mask` from two scalars; `_dtype` is not used here.
    fn if_then_else_broadcast_both(
        _dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        mask.iter()
            .map(|pick_true| match pick_true {
                true => if_true,
                false => if_false,
            })
            .collect_arr()
    }
}
211
#[cfg(feature = "dtype-struct")]
impl ChunkZip<StructType> for StructChunked {
    /// Zips two struct columns under a boolean mask.
    ///
    /// Values come from `self` where `mask` is true and from `other` where `mask` is false or
    /// null. `self`, `mask` and `other` may each have length 1, in which case they are
    /// broadcast; if any input is empty the output is empty. The fields are zipped one by one
    /// and the outer (struct-level) validity is rebuilt afterwards from the validities of both
    /// sides and the mask. The output takes the name of `self`.
    ///
    /// # Errors
    ///
    /// Returns `ShapeMismatch` when an input's length is neither 1 nor the output length.
    fn zip_with(
        &self,
        mask: &BooleanChunked,
        other: &ChunkedArray<StructType>,
    ) -> PolarsResult<ChunkedArray<StructType>> {
        let min_length = self.length.min(mask.length).min(other.length);
        let max_length = self.length.max(mask.length).max(other.length);

        let length = if min_length == 0 { 0 } else { max_length };

        // Validate shapes in every build profile, not only under debug assertions: mismatched
        // inputs must surface as an error (as in the non-struct `zip_with`) instead of
        // silently producing a wrongly sized result in release builds.
        let broadcastable = |len: usize| len == 1 || len == length;
        polars_ensure!(broadcastable(self.length), ShapeMismatch: SHAPE_MISMATCH_STR);
        polars_ensure!(broadcastable(mask.length), ShapeMismatch: SHAPE_MISMATCH_STR);
        polars_ensure!(broadcastable(other.length), ShapeMismatch: SHAPE_MISMATCH_STR);

        let mut if_true: Cow<ChunkedArray<StructType>> = Cow::Borrowed(self);
        let mut if_false: Cow<ChunkedArray<StructType>> = Cow::Borrowed(other);

        // Special case: a single mask value selects one side wholesale.
        // @TODO: Optimization. If all mask values are the same, select one of the two.
        if mask.length == 1 {
            // pl.when(None) <=> pl.when(False)
            let is_true = mask.get(0).unwrap_or(false);
            return Ok(if is_true && self.length == 1 {
                self.new_from_index(0, length)
            } else if is_true {
                self.clone()
            } else if other.length == 1 {
                let mut s = other.new_from_index(0, length);
                s.rename(self.name().clone());
                s
            } else {
                let mut s = other.clone();
                s.rename(self.name().clone());
                s
            });
        }

        // align_chunks_ternary can only align chunks if:
        // - Each chunkedarray only has 1 chunk
        // - Each chunkedarray has an equal length (i.e. is broadcasted)
        //
        // Therefore, we broadcast only those that are necessary to be broadcasted.
        let needs_broadcast =
            if_true.chunks().len() > 1 || if_false.chunks().len() > 1 || mask.chunks().len() > 1;
        if needs_broadcast && length > 1 {
            if self.length == 1 {
                let broadcasted = self.new_from_index(0, length);
                if_true = Cow::Owned(broadcasted);
            }
            if other.length == 1 {
                let broadcasted = other.new_from_index(0, length);
                if_false = Cow::Owned(broadcasted);
            }
        }

        let if_true = if_true.as_ref();
        let if_false = if_false.as_ref();

        let (if_true, if_false, mask) = align_chunks_ternary(if_true, if_false, mask);

        // Prepare the boolean arrays such that Null maps to false.
        // This prevents every field doing that.
        // SAFETY: We don't modify the length and we update the null count.
        let mut mask = mask.into_owned();
        unsafe {
            for arr in mask.downcast_iter_mut() {
                let bm = bool_null_to_false(arr);
                *arr = BooleanArray::from_data_default(bm, None);
            }
            mask.set_null_count(0);
        }

        // Zip all the fields.
        let fields = if_true
            .fields_as_series()
            .iter()
            .zip(if_false.fields_as_series())
            .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs))
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut out = StructChunked::from_series(self.name().clone(), length, fields.iter())?;

        /// Concatenates per-chunk validities into one bitmap of `total_length`.
        ///
        /// Returns `None` when no chunk has an unset bit; chunks without a validity (or with
        /// an all-set one) are filled with `true`.
        fn rechunk_bitmaps(
            total_length: usize,
            iter: impl Iterator<Item = (usize, Option<Bitmap>)>,
        ) -> Option<Bitmap> {
            let mut rechunked_length = 0;
            let mut rechunked_validity = None;
            for (chunk_length, validity) in iter {
                if let Some(validity) = validity {
                    if validity.unset_bits() > 0 {
                        rechunked_validity
                            .get_or_insert_with(|| {
                                let mut bm = BitmapBuilder::with_capacity(total_length);
                                bm.extend_constant(rechunked_length, true);
                                bm
                            })
                            .extend_from_bitmap(&validity);
                    }
                }

                rechunked_length += chunk_length;
            }

            if let Some(rechunked_validity) = rechunked_validity.as_mut() {
                rechunked_validity.extend_constant(total_length - rechunked_validity.len(), true);
            }

            rechunked_validity.map(BitmapBuilder::freeze)
        }

        // Zip the validities.
        //
        // We need to take two things into account:
        // 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`.
        // 2. `l` and `r` might still need to be broadcasted.
        if (if_true.null_count + if_false.null_count) > 0 {
            // Create one validity mask that spans the entirety of out.
            let rechunked_validity = match (if_true.len(), if_false.len()) {
                (1, 1) if length != 1 => {
                    match (if_true.null_count() == 0, if_false.null_count() == 0) {
                        (true, true) => None,
                        // Only `if_false` is valid: output is valid where the mask is false.
                        (false, true) => {
                            if mask.chunks().len() == 1 {
                                let m = mask.chunks()[0]
                                    .as_any()
                                    .downcast_ref::<BooleanArray>()
                                    .unwrap()
                                    .values();
                                Some(!m)
                            } else {
                                rechunk_bitmaps(
                                    length,
                                    mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))),
                                )
                            }
                        },
                        // Only `if_true` is valid: output is valid where the mask is true.
                        (true, false) => {
                            if mask.chunks().len() == 1 {
                                let m = mask.chunks()[0]
                                    .as_any()
                                    .downcast_ref::<BooleanArray>()
                                    .unwrap()
                                    .values();
                                Some(m.clone())
                            } else {
                                rechunk_bitmaps(
                                    length,
                                    mask.downcast_iter()
                                        .map(|m| (m.len(), Some(m.values().clone()))),
                                )
                            }
                        },
                        (false, false) => Some(Bitmap::new_zeroed(length)),
                    }
                },
                (1, _) if length != 1 => {
                    debug_assert!(
                        if_false
                            .chunk_lengths()
                            .zip(mask.chunk_lengths())
                            .all(|(r, m)| r == m)
                    );

                    let combine = if if_true.null_count() == 0 {
                        |if_false: Option<&Bitmap>, m: &Bitmap| {
                            if_false.map(|v| arrow::bitmap::or(v, m))
                        }
                    } else {
                        |if_false: Option<&Bitmap>, m: &Bitmap| {
                            Some(if_false.map_or_else(|| !m, |v| arrow::bitmap::and_not(v, m)))
                        }
                    };

                    if if_false.chunks().len() == 1 {
                        let if_false = if_false.chunks()[0].validity();
                        let m = mask.chunks()[0]
                            .as_any()
                            .downcast_ref::<BooleanArray>()
                            .unwrap()
                            .values();

                        let validity = combine(if_false, m);
                        validity.filter(|v| v.unset_bits() > 0)
                    } else {
                        rechunk_bitmaps(
                            length,
                            if_false.chunks().iter().zip(mask.downcast_iter()).map(
                                |(chunk, mask)| {
                                    (mask.len(), combine(chunk.validity(), mask.values()))
                                },
                            ),
                        )
                    }
                },
                (_, 1) if length != 1 => {
                    debug_assert!(
                        if_true
                            .chunk_lengths()
                            .zip(mask.chunk_lengths())
                            .all(|(l, m)| l == m)
                    );

                    let combine = if if_false.null_count() == 0 {
                        |if_true: Option<&Bitmap>, m: &Bitmap| {
                            if_true.map(|v| arrow::bitmap::or_not(v, m))
                        }
                    } else {
                        |if_true: Option<&Bitmap>, m: &Bitmap| {
                            Some(if_true.map_or_else(|| m.clone(), |v| arrow::bitmap::and(v, m)))
                        }
                    };

                    if if_true.chunks().len() == 1 {
                        let if_true = if_true.chunks()[0].validity();
                        let m = mask.chunks()[0]
                            .as_any()
                            .downcast_ref::<BooleanArray>()
                            .unwrap()
                            .values();

                        let validity = combine(if_true, m);
                        validity.filter(|v| v.unset_bits() > 0)
                    } else {
                        rechunk_bitmaps(
                            length,
                            if_true.chunks().iter().zip(mask.downcast_iter()).map(
                                |(chunk, mask)| {
                                    (mask.len(), combine(chunk.validity(), mask.values()))
                                },
                            ),
                        )
                    }
                },
                (_, _) => {
                    debug_assert!(
                        if_true
                            .chunk_lengths()
                            .zip(if_false.chunk_lengths())
                            .all(|(l, r)| l == r)
                    );
                    debug_assert!(
                        if_true
                            .chunk_lengths()
                            .zip(mask.chunk_lengths())
                            .all(|(l, r)| l == r)
                    );

                    let validities = if_true
                        .chunks()
                        .iter()
                        .zip(if_false.chunks())
                        .map(|(l, r)| (l.validity(), r.validity()));

                    rechunk_bitmaps(
                        length,
                        validities
                            .zip(mask.downcast_iter())
                            .map(|((if_true, if_false), mask)| {
                                (
                                    mask.len(),
                                    if_then_else_validity(mask.values(), if_true, if_false),
                                )
                            }),
                    )
                },
            };

            // Apply the validity spreading over the chunks of out.
            if let Some(mut rechunked_validity) = rechunked_validity {
                assert_eq!(rechunked_validity.len(), out.len());

                let num_chunks = out.chunks().len();
                let null_count = rechunked_validity.unset_bits();

                // SAFETY: We do not change the lengths of the chunks and we update the null_count
                // afterwards.
                let chunks = unsafe { out.chunks_mut() };

                if num_chunks == 1 {
                    chunks[0] = chunks[0].with_validity(Some(rechunked_validity));
                } else {
                    for chunk in chunks {
                        let chunk_len = chunk.len();
                        let chunk_validity;

                        // SAFETY: We know that rechunked_validity.len() == out.len()
                        (chunk_validity, rechunked_validity) =
                            unsafe { rechunked_validity.split_at_unchecked(chunk_len) };
                        *chunk = chunk.with_validity(
                            (chunk_validity.unset_bits() > 0).then_some(chunk_validity),
                        );
                    }
                }

                out.null_count = null_count;
            } else {
                // SAFETY: We do not change the lengths of the chunks and we update the null_count
                // afterwards.
                let chunks = unsafe { out.chunks_mut() };

                for chunk in chunks {
                    *chunk = chunk.with_validity(None);
                }

                out.null_count = 0;
            }
        }

        if cfg!(debug_assertions) {
            let start_length = out.len();
            let start_null_count = out.null_count();

            out.compute_len();

            assert_eq!(start_length, out.len());
            assert_eq!(start_null_count, out.null_count());
        }
        Ok(out)
    }
}