1use std::borrow::Cow;
2
3use arrow::bitmap::{Bitmap, BitmapBuilder};
4use arrow::compute::utils::{combine_validities_and, combine_validities_and_not};
5use polars_compute::if_then_else::{IfThenElseKernel, if_then_else_validity};
6
7#[cfg(feature = "object")]
8use crate::chunked_array::object::ObjectArray;
9use crate::prelude::*;
10use crate::utils::{align_chunks_binary, align_chunks_ternary};
11
/// Error message used for every `ShapeMismatch` raised by the `zip_with`
/// code paths in this file.
const SHAPE_MISMATCH_STR: &str =
    "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation";
14
15fn if_then_else_broadcast_mask<T: PolarsDataType>(
16 mask: bool,
17 if_true: &ChunkedArray<T>,
18 if_false: &ChunkedArray<T>,
19) -> PolarsResult<ChunkedArray<T>>
20where
21 ChunkedArray<T>: ChunkExpandAtIndex<T>,
22{
23 let src = if mask { if_true } else { if_false };
24 let other = if mask { if_false } else { if_true };
25 let ret = match (src.len(), other.len()) {
26 (a, b) if a == b => src.clone(),
27 (_, 1) => src.clone(),
28 (1, other_len) => src.new_from_index(0, other_len),
29 _ => polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR),
30 };
31 Ok(ret.with_name(if_true.name().clone()))
32}
33
34fn bool_null_to_false(mask: &BooleanArray) -> Bitmap {
35 if mask.null_count() == 0 {
36 mask.values().clone()
37 } else {
38 mask.values() & mask.validity().unwrap()
39 }
40}
41
42fn combine_validities_chunked<
46 T: PolarsDataType,
47 F: Fn(Option<&Bitmap>, Option<&Bitmap>) -> Option<Bitmap>,
48>(
49 ca: &ChunkedArray<T>,
50 mask: &BooleanChunked,
51 combiner: F,
52) -> ChunkedArray<T> {
53 let (ca_al, mask_al) = align_chunks_binary(ca, mask);
54 let chunks = ca_al
55 .downcast_iter()
56 .zip(mask_al.downcast_iter())
57 .map(|(a, m)| {
58 let bm = bool_null_to_false(m);
59 let validity = combiner(a.validity(), Some(&bm));
60 a.clone().with_validity_typed(validity)
61 });
62 ChunkedArray::from_chunk_iter_like(ca, chunks)
63}
64
65impl<T> ChunkZip<T> for ChunkedArray<T>
66where
67 T: PolarsDataType<IsStruct = FalseT>,
68 T::Array: for<'a> IfThenElseKernel<Scalar<'a> = T::Physical<'a>>,
69 ChunkedArray<T>: ChunkExpandAtIndex<T>,
70{
71 fn zip_with(
72 &self,
73 mask: &BooleanChunked,
74 other: &ChunkedArray<T>,
75 ) -> PolarsResult<ChunkedArray<T>> {
76 let if_true = self;
77 let if_false = other;
78
79 if mask.len() == 1 {
81 return if_then_else_broadcast_mask(mask.get(0).unwrap_or(false), if_true, if_false);
82 }
83
84 let ret = if if_true.len() == 1 && if_false.len() == 1 {
86 match (if_true.get(0), if_false.get(0)) {
87 (None, None) => ChunkedArray::full_null_like(if_true, mask.len()),
88 (None, Some(_)) => combine_validities_chunked(
89 &if_false.new_from_index(0, mask.len()),
90 mask,
91 combine_validities_and_not,
92 ),
93 (Some(_), None) => combine_validities_chunked(
94 &if_true.new_from_index(0, mask.len()),
95 mask,
96 combine_validities_and,
97 ),
98 (Some(t), Some(f)) => {
99 let dtype = if_true.downcast_iter().next().unwrap().dtype();
100 let chunks = mask.downcast_iter().map(|m| {
101 let bm = bool_null_to_false(m);
102 let t = t.clone();
103 let f = f.clone();
104 IfThenElseKernel::if_then_else_broadcast_both(dtype.clone(), &bm, t, f)
105 });
106 ChunkedArray::from_chunk_iter_like(if_true, chunks)
107 },
108 }
109
110 } else if if_true.len() == if_false.len() {
112 polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR);
113 let (mask_al, if_true_al, if_false_al) = align_chunks_ternary(mask, if_true, if_false);
114 let chunks = mask_al
115 .downcast_iter()
116 .zip(if_true_al.downcast_iter())
117 .zip(if_false_al.downcast_iter())
118 .map(|((m, t), f)| IfThenElseKernel::if_then_else(&bool_null_to_false(m), t, f));
119 ChunkedArray::from_chunk_iter_like(if_true, chunks)
120
121 } else if if_true.len() == 1 {
123 polars_ensure!(mask.len() == if_false.len(), ShapeMismatch: SHAPE_MISMATCH_STR);
124 if let Some(true_scalar) = if_true.get(0) {
125 let (mask_al, if_false_al) = align_chunks_binary(mask, if_false);
126 let chunks = mask_al
127 .downcast_iter()
128 .zip(if_false_al.downcast_iter())
129 .map(|(m, f)| {
130 let bm = bool_null_to_false(m);
131 let t = true_scalar.clone();
132 IfThenElseKernel::if_then_else_broadcast_true(&bm, t, f)
133 });
134 ChunkedArray::from_chunk_iter_like(if_true, chunks)
135 } else {
136 combine_validities_chunked(if_false, mask, combine_validities_and_not)
137 }
138
139 } else if if_false.len() == 1 {
141 polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR);
142 if let Some(false_scalar) = if_false.get(0) {
143 let (mask_al, if_true_al) = align_chunks_binary(mask, if_true);
144 let chunks =
145 mask_al
146 .downcast_iter()
147 .zip(if_true_al.downcast_iter())
148 .map(|(m, t)| {
149 let bm = bool_null_to_false(m);
150 let f = false_scalar.clone();
151 IfThenElseKernel::if_then_else_broadcast_false(&bm, t, f)
152 });
153 ChunkedArray::from_chunk_iter_like(if_false, chunks)
154 } else {
155 combine_validities_chunked(if_true, mask, combine_validities_and)
156 }
157 } else {
158 polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR)
159 };
160
161 Ok(ret.with_name(if_true.name().clone()))
162 }
163}
164
/// Element-by-element `if_then_else` kernels for object arrays, built on plain
/// iterators and `collect_arr`.
#[cfg(feature = "object")]
impl<T: PolarsObject> IfThenElseKernel for ObjectArray<T> {
    type Scalar<'a> = &'a T;

    /// Picks, per position, the element of `if_true` when the mask bit is set
    /// and the element of `if_false` otherwise.
    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
        let candidates = if_true.iter().zip(if_false.iter());
        mask.iter()
            .zip(candidates)
            .map(|(take_true, (on_true, on_false))| {
                if take_true { on_true } else { on_false }
            })
            .collect_arr()
    }

    /// Like `if_then_else`, with a single `if_true` value used for every set
    /// mask bit.
    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self {
        mask.iter()
            .zip(if_false.iter())
            .map(|(take_true, on_false)| take_true.then_some(if_true).or(on_false))
            .collect_arr()
    }

    /// Like `if_then_else`, with a single `if_false` value used for every
    /// unset mask bit.
    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        mask.iter()
            .zip(if_true.iter())
            .map(|(take_true, on_true)| (!take_true).then_some(if_false).or(on_true))
            .collect_arr()
    }

    /// Produces one element per mask bit: `if_true` for set bits and
    /// `if_false` for unset bits. The dtype argument is not used.
    fn if_then_else_broadcast_both(
        _dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let select = |take_true: bool| if take_true { if_true } else { if_false };
        mask.iter().map(select).collect_arr()
    }
}
211
#[cfg(feature = "dtype-struct")]
impl ChunkZip<StructType> for StructChunked {
    /// Builds a struct array taking rows from `self` where `mask` is true and
    /// from `other` where it is false; null mask entries count as false.
    ///
    /// Any of the three inputs may have length 1 and is then broadcast. The
    /// fields are zipped individually and the outer (struct-level) validity is
    /// zipped separately afterwards. The result is named after `self`.
    ///
    /// # Errors
    ///
    /// Propagates errors from zipping the individual fields and from
    /// re-assembling the struct.
    fn zip_with(
        &self,
        mask: &BooleanChunked,
        other: &ChunkedArray<StructType>,
    ) -> PolarsResult<ChunkedArray<StructType>> {
        let min_length = self.length.min(mask.length).min(other.length);
        let max_length = self.length.max(mask.length).max(other.length);

        // Any empty input makes the output empty; otherwise the longest input
        // determines the output length.
        let length = if min_length == 0 { 0 } else { max_length };

        debug_assert!(self.length == 1 || self.length == length);
        debug_assert!(mask.length == 1 || mask.length == length);
        debug_assert!(other.length == 1 || other.length == length);

        let mut if_true: Cow<ChunkedArray<StructType>> = Cow::Borrowed(self);
        let mut if_false: Cow<ChunkedArray<StructType>> = Cow::Borrowed(other);

        // Unit-length mask: pick one side wholesale and return early.
        if mask.length == 1 {
            let is_true = mask.get(0).unwrap_or(false);
            return Ok(if is_true && self.length == 1 {
                self.new_from_index(0, length)
            } else if is_true {
                self.clone()
            } else if other.length == 1 {
                let mut s = other.new_from_index(0, length);
                s.rename(self.name().clone());
                s
            } else {
                let mut s = other.clone();
                s.rename(self.name().clone());
                s
            });
        }

        // Only materialize the broadcast of a unit-length side when some input
        // is split over several chunks.
        // NOTE(review): presumably `align_chunks_ternary` needs equal lengths
        // to align multi-chunk inputs — confirm against its implementation.
        let needs_broadcast =
            if_true.chunks().len() > 1 || if_false.chunks().len() > 1 || mask.chunks().len() > 1;
        if needs_broadcast && length > 1 {
            if self.length == 1 {
                let broadcasted = self.new_from_index(0, length);
                if_true = Cow::Owned(broadcasted);
            }
            if other.length == 1 {
                let broadcasted = other.new_from_index(0, length);
                if_false = Cow::Owned(broadcasted);
            }
        }

        let if_true = if_true.as_ref();
        let if_false = if_false.as_ref();

        let (if_true, if_false, mask) = align_chunks_ternary(if_true, if_false, mask);

        // Map null mask entries to false once, up front, so the mask values
        // can be used directly below.
        let mut mask = mask.into_owned();
        // SAFETY: every chunk is replaced by an array built from a bitmap
        // derived from that chunk's own values (same length), and since the
        // new arrays carry no validity the null count is reset to 0.
        unsafe {
            for arr in mask.downcast_iter_mut() {
                let bm = bool_null_to_false(arr);
                *arr = BooleanArray::from_data_default(bm, None);
            }
            mask.set_null_count(0);
        }

        // Zip all the fields.
        let fields = if_true
            .fields_as_series()
            .iter()
            .zip(if_false.fields_as_series())
            .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs))
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut out = StructChunked::from_series(self.name().clone(), length, fields.iter())?;

        /// Concatenates per-chunk validities into a single bitmap of
        /// `total_length` bits.
        ///
        /// Each item is `(chunk_length, validity)`. A chunk whose validity is
        /// `None` or has no unset bits contributes `chunk_length` set bits.
        /// Returns `None` when no chunk has an unset bit.
        fn rechunk_bitmaps(
            total_length: usize,
            iter: impl Iterator<Item = (usize, Option<Bitmap>)>,
        ) -> Option<Bitmap> {
            let mut rechunked_length = 0;
            let mut rechunked_validity: Option<BitmapBuilder> = None;
            for (chunk_length, validity) in iter {
                if let Some(validity) = validity.filter(|v| v.unset_bits() > 0) {
                    let builder = rechunked_validity
                        .get_or_insert_with(|| BitmapBuilder::with_capacity(total_length));
                    // Fill in set bits for every all-valid chunk skipped since
                    // the last append (or since the start), so that this
                    // chunk's bits land at the right offset.
                    builder.extend_constant(rechunked_length - builder.len(), true);
                    builder.extend_from_bitmap(&validity);
                }

                rechunked_length += chunk_length;
            }

            // Trailing all-valid chunks still need their set bits.
            if let Some(rechunked_validity) = rechunked_validity.as_mut() {
                rechunked_validity.extend_constant(total_length - rechunked_validity.len(), true);
            }

            rechunked_validity.map(BitmapBuilder::freeze)
        }

        // Zip the outer validities. The chunk layout of `out` is not assumed
        // to match the inputs, and a unit-length side may still be
        // un-broadcast at this point.
        if (if_true.null_count + if_false.null_count) > 0 {
            // Build one validity bitmap spanning all of `out`.
            let rechunked_validity = match (if_true.len(), if_false.len()) {
                (1, 1) if length != 1 => {
                    match (if_true.null_count() == 0, if_false.null_count() == 0) {
                        (true, true) => None,
                        // Only `if_true` is null: valid where the mask is unset.
                        (false, true) => {
                            if mask.chunks().len() == 1 {
                                let m = mask.chunks()[0]
                                    .as_any()
                                    .downcast_ref::<BooleanArray>()
                                    .unwrap()
                                    .values();
                                Some(!m)
                            } else {
                                rechunk_bitmaps(
                                    length,
                                    mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))),
                                )
                            }
                        },
                        // Only `if_false` is null: valid where the mask is set.
                        (true, false) => {
                            if mask.chunks().len() == 1 {
                                let m = mask.chunks()[0]
                                    .as_any()
                                    .downcast_ref::<BooleanArray>()
                                    .unwrap()
                                    .values();
                                Some(m.clone())
                            } else {
                                rechunk_bitmaps(
                                    length,
                                    mask.downcast_iter()
                                        .map(|m| (m.len(), Some(m.values().clone()))),
                                )
                            }
                        },
                        (false, false) => Some(Bitmap::new_zeroed(length)),
                    }
                },
                (1, _) if length != 1 => {
                    debug_assert!(
                        if_false
                            .chunk_lengths()
                            .zip(mask.chunk_lengths())
                            .all(|(r, m)| r == m)
                    );

                    let combine = if if_true.null_count() == 0 {
                        // Scalar true side is valid: valid where `if_false` is
                        // valid or the mask is set.
                        |if_false: Option<&Bitmap>, m: &Bitmap| {
                            if_false.map(|v| arrow::bitmap::or(v, m))
                        }
                    } else {
                        // Scalar true side is null: valid where `if_false` is
                        // valid and the mask is unset.
                        |if_false: Option<&Bitmap>, m: &Bitmap| {
                            Some(if_false.map_or_else(|| !m, |v| arrow::bitmap::and_not(v, m)))
                        }
                    };

                    if if_false.chunks().len() == 1 {
                        let if_false = if_false.chunks()[0].validity();
                        let m = mask.chunks()[0]
                            .as_any()
                            .downcast_ref::<BooleanArray>()
                            .unwrap()
                            .values();

                        let validity = combine(if_false, m);
                        validity.filter(|v| v.unset_bits() > 0)
                    } else {
                        rechunk_bitmaps(
                            length,
                            if_false.chunks().iter().zip(mask.downcast_iter()).map(
                                |(chunk, mask)| {
                                    (mask.len(), combine(chunk.validity(), mask.values()))
                                },
                            ),
                        )
                    }
                },
                (_, 1) if length != 1 => {
                    debug_assert!(
                        if_true
                            .chunk_lengths()
                            .zip(mask.chunk_lengths())
                            .all(|(l, m)| l == m)
                    );

                    let combine = if if_false.null_count() == 0 {
                        // Scalar false side is valid: valid where `if_true` is
                        // valid or the mask is unset.
                        |if_true: Option<&Bitmap>, m: &Bitmap| {
                            if_true.map(|v| arrow::bitmap::or_not(v, m))
                        }
                    } else {
                        // Scalar false side is null: valid where `if_true` is
                        // valid and the mask is set.
                        |if_true: Option<&Bitmap>, m: &Bitmap| {
                            Some(if_true.map_or_else(|| m.clone(), |v| arrow::bitmap::and(v, m)))
                        }
                    };

                    if if_true.chunks().len() == 1 {
                        let if_true = if_true.chunks()[0].validity();
                        let m = mask.chunks()[0]
                            .as_any()
                            .downcast_ref::<BooleanArray>()
                            .unwrap()
                            .values();

                        let validity = combine(if_true, m);
                        validity.filter(|v| v.unset_bits() > 0)
                    } else {
                        rechunk_bitmaps(
                            length,
                            if_true.chunks().iter().zip(mask.downcast_iter()).map(
                                |(chunk, mask)| {
                                    (mask.len(), combine(chunk.validity(), mask.values()))
                                },
                            ),
                        )
                    }
                },
                (_, _) => {
                    debug_assert!(
                        if_true
                            .chunk_lengths()
                            .zip(if_false.chunk_lengths())
                            .all(|(l, r)| l == r)
                    );
                    debug_assert!(
                        if_true
                            .chunk_lengths()
                            .zip(mask.chunk_lengths())
                            .all(|(l, r)| l == r)
                    );

                    let validities = if_true
                        .chunks()
                        .iter()
                        .zip(if_false.chunks())
                        .map(|(l, r)| (l.validity(), r.validity()));

                    rechunk_bitmaps(
                        length,
                        validities
                            .zip(mask.downcast_iter())
                            .map(|((if_true, if_false), mask)| {
                                (
                                    mask.len(),
                                    if_then_else_validity(mask.values(), if_true, if_false),
                                )
                            }),
                    )
                },
            };

            if let Some(mut rechunked_validity) = rechunked_validity {
                assert_eq!(rechunked_validity.len(), out.len());

                let num_chunks = out.chunks().len();
                let null_count = rechunked_validity.unset_bits();

                // SAFETY: only the validity of each chunk is replaced, and
                // `out.null_count` is updated below to match.
                let chunks = unsafe { out.chunks_mut() };

                if num_chunks == 1 {
                    chunks[0] = chunks[0].with_validity(Some(rechunked_validity));
                } else {
                    for chunk in chunks {
                        let chunk_len = chunk.len();
                        let chunk_validity;

                        // SAFETY: the bitmap length was asserted equal to
                        // `out.len()` above.
                        // NOTE(review): this relies on `out.len()` being the
                        // sum of its chunk lengths, so each split stays in
                        // bounds — confirm.
                        (chunk_validity, rechunked_validity) =
                            unsafe { rechunked_validity.split_at_unchecked(chunk_len) };
                        *chunk = chunk.with_validity(
                            (chunk_validity.unset_bits() > 0).then_some(chunk_validity),
                        );
                    }
                }

                out.null_count = null_count;
            } else {
                // SAFETY: only the validity of each chunk is cleared, and
                // `out.null_count` is reset to 0 below to match.
                let chunks = unsafe { out.chunks_mut() };

                for chunk in chunks {
                    *chunk = chunk.with_validity(None);
                }

                out.null_count = 0;
            }
        }

        // In debug builds, check that the cached length and null count agree
        // with a fresh recomputation.
        if cfg!(debug_assertions) {
            let start_length = out.len();
            let start_null_count = out.null_count();

            out.compute_len();

            assert_eq!(start_length, out.len());
            assert_eq!(start_null_count, out.null_count());
        }
        Ok(out)
    }
}