polars_core/series/arithmetic/
fixed_size_list.rs

1use polars_error::{PolarsResult, feature_gated};
2
3use super::list_utils::NumericOp;
4use super::{ArrayChunked, FixedSizeListType, IntoSeries, NumOpsDispatchInner, Series};
5
6impl NumOpsDispatchInner for FixedSizeListType {
7    fn add_to(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
8        NumericFixedSizeListOp::add().execute(&lhs.clone().into_series(), rhs)
9    }
10
11    fn subtract(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
12        NumericFixedSizeListOp::sub().execute(&lhs.clone().into_series(), rhs)
13    }
14
15    fn multiply(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
16        NumericFixedSizeListOp::mul().execute(&lhs.clone().into_series(), rhs)
17    }
18
19    fn divide(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
20        NumericFixedSizeListOp::div().execute(&lhs.clone().into_series(), rhs)
21    }
22
23    fn remainder(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
24        NumericFixedSizeListOp::rem().execute(&lhs.clone().into_series(), rhs)
25    }
26}
27
28#[derive(Clone)]
29pub struct NumericFixedSizeListOp(NumericOp);
30
31impl NumericFixedSizeListOp {
32    pub fn add() -> Self {
33        Self(NumericOp::Add)
34    }
35
36    pub fn sub() -> Self {
37        Self(NumericOp::Sub)
38    }
39
40    pub fn mul() -> Self {
41        Self(NumericOp::Mul)
42    }
43
44    pub fn div() -> Self {
45        Self(NumericOp::Div)
46    }
47
48    pub fn rem() -> Self {
49        Self(NumericOp::Rem)
50    }
51
52    pub fn floor_div() -> Self {
53        Self(NumericOp::FloorDiv)
54    }
55}
56
57impl NumericFixedSizeListOp {
58    #[cfg_attr(not(feature = "array_arithmetic"), allow(unused))]
59    pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
60        feature_gated!("array_arithmetic", {
61            NumericFixedSizeListOpHelper::execute_op(self.clone(), lhs.rechunk(), rhs.rechunk())
62                .map(|x| x.into_series())
63        })
64    }
65}
66
67#[cfg(feature = "array_arithmetic")]
68use inner::NumericFixedSizeListOpHelper;
69
70#[cfg(feature = "array_arithmetic")]
71mod inner {
72    use arrow::bitmap::{Bitmap, BitmapBuilder};
73    use arrow::compute::utils::combine_validities_and;
74    use fixed_size_list::NumericFixedSizeListOp;
75    use list_utils::with_match_pl_num_arith;
76    use num_traits::Zero;
77    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
78    use polars_utils::float::IsFloat;
79
80    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
81    use super::super::*;
82
83    /// Utility to perform a binary operation between the primitive values of
84    /// 2 columns, where at least one of the columns is a `ArrayChunked` type.
85    pub(super) struct NumericFixedSizeListOpHelper {
86        op: NumericFixedSizeListOp,
87        output_name: PlSmallStr,
88        /// We are just re-using the enum used for list arithmetic.
89        op_apply_type: BinaryOpApplyType,
90        broadcast: Broadcast,
91        /// Stride of the leaf array
92        stride: usize,
93        /// Widths at every level
94        output_widths: Vec<usize>,
95        output_dtype: DataType,
96        output_primitive_dtype: DataType,
97        /// Length of the outermost level
98        output_len: usize,
99        data_lhs: (Series, Vec<Option<Bitmap>>),
100        data_rhs: (Series, Vec<Option<Bitmap>>),
101        swapped: bool,
102    }
103
    /// Validation and operand rewriting live in the non-generic `execute_op()` and `finish()`,
    /// so that only `_finish_impl()` is monomorphized for each numeric type.
106    impl NumericFixedSizeListOpHelper {
        /// Validates the operands, then computes `lhs <op> rhs`.
        ///
        /// Checks that:
        /// * Leaf dtypes are supported arithmetic inputs and have a common leaf supertype.
        /// * Dtypes are compatible:
        ///   * array<->primitive | primitive<->array
        ///   * array<->array, where both sides are `Array` at every nesting level and are
        ///     equal once their leaves are cast to the common leaf supertype.
        /// * Lengths are compatible:
        ///   * 1<->n | n<->1 (the length-1 side is broadcast)
        ///   * n<->n
        ///
        /// Array<->array without broadcasting is computed directly here. If the output length
        /// is 0, or an array-typed operand has a validity with no set bits (every outer row is
        /// NULL), a full-NULL result is returned. All other cases are handed to `finish()`.
        ///
        /// # Errors
        /// * `ComputeError` if either leaf dtype is not a supported arithmetic input.
        /// * `ShapeMismatch` if the lengths are neither equal nor broadcastable.
        /// * `InvalidOperation` if the dtypes are otherwise incompatible.
        ///
        /// Errors from leaf-supertype resolution and from casting are propagated.
        ///
        /// # Panics
        /// If either input has more than one chunk.
        pub(super) fn execute_op(
            op: NumericFixedSizeListOp,
            lhs: Series,
            rhs: Series,
        ) -> PolarsResult<ArrayChunked> {
            // Later steps (e.g. `push_array_validities_recursive`) only look at the first chunk.
            assert_eq!(lhs.chunks().len(), 1);
            assert_eq!(rhs.chunks().len(), 1);

            let dtype_lhs = lhs.dtype();
            let dtype_rhs = rhs.dtype();

            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            //
            // Check leaf dtypes
            //

            if !(prim_dtype_lhs.is_supported_list_arithmetic_input()
                && prim_dtype_rhs.is_supported_list_arithmetic_input())
            {
                polars_bail!(
                    ComputeError: "cannot {} non-numeric inner dtypes: (left: {}, right: {})",
                    op.0.name(), prim_dtype_lhs, prim_dtype_rhs
                )
            }

            // Both leaf arrays are cast to this dtype before the operation is applied.
            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            // Returns `true` if every nesting level of `dtype` is an `Array`, down to a
            // supported leaf dtype (a bare supported dtype also returns `true`).
            fn is_array_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::Array(inner, ..) => is_array_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            // Pushes the width of every `Array` level (outermost first) onto `widths` and
            // returns the stride: the number of leaf values per outer row, i.e. the product
            // of all widths.
            fn array_stride_and_widths(dtype: &DataType, widths: &mut Vec<usize>) -> usize {
                if let DataType::Array(inner, size_inner) = dtype {
                    widths.push(*size_inner);
                    *size_inner * array_stride_and_widths(inner.as_ref(), widths)
                } else {
                    1
                }
            }

            //
            // Get broadcasting information and output length
            //

            let len_lhs = lhs.len();
            let len_rhs = rhs.len();

            // A length-1 side is broadcast against the other side. Equal lengths (including
            // 1<->1) never broadcast.
            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            //
            // Get validities for array levels
            //

            // Appends the validity of each nested `FixedSizeListArray` level of `s`,
            // outermost first. Pushes nothing if `s` is not an array column.
            fn push_array_validities_recursive(s: &Series, out: &mut Vec<Option<Bitmap>>) {
                let mut opt_arr = s.array().ok().map(|x| {
                    assert_eq!(x.chunks().len(), 1);
                    x.downcast_get(0).unwrap()
                });

                while let Some(arr) = opt_arr {
                    // Push none if all-valid, this can potentially save some `repeat_bitmap()`
                    // materializations on broadcasting paths.
                    out.push(arr.validity().filter(|x| x.unset_bits() > 0).cloned());
                    opt_arr = arr.values().as_any().downcast_ref::<FixedSizeListArray>();
                }
            }

            let mut array_validities_lhs = vec![];
            let mut array_validities_rhs = vec![];

            push_array_validities_recursive(&lhs, &mut array_validities_lhs);
            push_array_validities_recursive(&rhs, &mut array_validities_rhs);

            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_array_type_at_all_levels = |dtype: &DataType| {
                if !is_array_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not array on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            //
            // Check full dtypes and get output widths
            //

            let mut output_widths = vec![];

            let (op_apply_type, stride, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (dtype_lhs @ DataType::Array(..), dtype_rhs @ DataType::Array(..)) => {
                    // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user
                    // directly adds 2 series together it bypasses the DSL.
                    // This is currently duplicated code and should be replaced one day with an assert after Series ops get
                    // checked properly.

                    if dtype_lhs.cast_leaf(output_primitive_dtype.clone())
                        != dtype_rhs.cast_leaf(output_primitive_dtype.clone())
                    {
                        return Err(op_err_msg("differing dtypes"));
                    };

                    // We only check dtype_lhs since both dtypes were checked to be equal
                    // (after leaf casting) above.
                    ensure_array_type_at_all_levels(dtype_lhs)?;

                    let stride = array_stride_and_widths(dtype_lhs, &mut output_widths);

                    // For array<->array without broadcasting we return early here to avoid the rest
                    // of the setup code and dispatch layers.
                    if let Broadcast::NoBroadcast = broadcast {
                        let out = op.0.apply_series(
                            &lhs.get_leaf_array().cast(&output_primitive_dtype)?,
                            &rhs.get_leaf_array().cast(&output_primitive_dtype)?,
                        );

                        return Ok(finish_array_to_array_no_broadcast(
                            lhs.name().clone(),
                            &output_widths,
                            output_len,
                            &array_validities_lhs,
                            &array_validities_rhs,
                            out,
                        ));
                    }

                    (BinaryOpApplyType::ListToList, stride, dtype_lhs)
                },
                (array_dtype @ DataType::Array(..), x)
                    if x.is_supported_list_arithmetic_input() =>
                {
                    ensure_array_type_at_all_levels(array_dtype)?;

                    let stride = array_stride_and_widths(array_dtype, &mut output_widths);
                    (BinaryOpApplyType::ListToPrimitive, stride, array_dtype)
                },
                (x, array_dtype @ DataType::Array(..))
                    if x.is_supported_list_arithmetic_input() =>
                {
                    ensure_array_type_at_all_levels(array_dtype)?;

                    let stride = array_stride_and_widths(array_dtype, &mut output_widths);
                    (BinaryOpApplyType::PrimitiveToList, stride, array_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "cannot {} dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            assert!(!output_widths.is_empty());

            if cfg!(debug_assertions) {
                // Array<->array: both sides have one validity entry per nesting level.
                // Otherwise only the array-typed side has entries.
                match (array_validities_lhs.len(), array_validities_rhs.len()) {
                    (l, r) if l == output_widths.len() && l == r && l > 0 => {},
                    (v, 0) | (0, v) if v == output_widths.len() => {},
                    _ => panic!(), // One side should have been an array.
                }
            }

            // Short-circuit to an all-NULL output when there are no rows, or when an
            // array-typed operand has a validity with no set bits (every outer row is NULL).
            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && lhs.rechunk_validity().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && rhs.rechunk_validity().is_some_and(|x| x.set_bits() == 0))
            {
                let DataType::Array(inner_dtype, width) = output_dtype else {
                    unreachable!()
                };

                Ok(ArrayChunked::full_null_with_dtype(
                    lhs.name().clone(),
                    output_len,
                    inner_dtype.as_ref(),
                    width,
                ))
            } else {
                Self {
                    op,
                    output_name: lhs.name().clone(),
                    op_apply_type,
                    broadcast,
                    stride,
                    output_widths,
                    output_dtype,
                    output_primitive_dtype,
                    output_len,
                    data_lhs: (lhs, array_validities_lhs),
                    data_rhs: (rhs, array_validities_rhs),
                    swapped: false,
                }
                .finish()
            }
        }
341
342        pub(super) fn finish(mut self) -> PolarsResult<ArrayChunked> {
343            // We have physical codepaths for a subset of the possible combinations of broadcasting and
344            // column types. The remaining combinations are handled by dispatching to the physical
345            // codepaths after operand swapping.
346            //
347            // # Physical impl table
348            // Legend
349            // * |  N  | // impl "N"
350            // * | [N] | // dispatches to impl "N"
351            //
352            //                  |  L  |  N  |  R  | // Broadcast (L)eft, (N)oBroadcast, (R)ight
353            // ListToList       | [1] |  0  |  1  |
354            // ListToPrimitive  |  2  |  3  |  4  |
355            // PrimitiveToList  | [4] | [3] | [2] |
356
357            self.swapped = true;
358
359            match (&self.op_apply_type, &self.broadcast) {
360                // Mostly the same as ListNumericOp, however with fixed size list we also have
361                // (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) as a physical impl.
362                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => unreachable!(), // We return earlier for this
363                (BinaryOpApplyType::ListToList, Broadcast::Right)
364                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
365                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
366                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
367                    self.swapped = false;
368                    self._finish_impl_dispatch()
369                },
370                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
371                    self.broadcast = Broadcast::Right;
372
373                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
374                    self._finish_impl_dispatch()
375                },
376                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
377                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
378                    self.broadcast = Broadcast::Left;
379
380                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
381                    self._finish_impl_dispatch()
382                },
383
384                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
385                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
386
387                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
388                    self._finish_impl_dispatch()
389                },
390                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
391                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
392                    self.broadcast = Broadcast::Right;
393
394                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
395                    self._finish_impl_dispatch()
396                },
397            }
398        }
399
400        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ArrayChunked> {
401            let output_dtype = self.output_dtype.clone();
402            let output_len = self.output_len;
403
404            let prim_lhs = self
405                .data_lhs
406                .0
407                .get_leaf_array()
408                .cast(&self.output_primitive_dtype)?
409                .rechunk();
410            let prim_rhs = self
411                .data_rhs
412                .0
413                .get_leaf_array()
414                .cast(&self.output_primitive_dtype)?
415                .rechunk();
416
417            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
418            let prim_dtype = prim_lhs.dtype();
419            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);
420
421            // Safety: Leaf dtypes have been checked to be numeric by `try_new()`
422            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
423                self._finish_impl::<$T>(prim_lhs, prim_rhs)
424            });
425
426            debug_assert_eq!(out.dtype(), &output_dtype);
427            assert_eq!(out.len(), output_len);
428
429            Ok(out)
430        }
431
        /// Internal use only - contains physical impls.
        ///
        /// Must only be reached (via `_finish_impl_dispatch()`) after `finish()` has rewritten
        /// the operation onto a physical combination: `ListToList` with `Broadcast::Right`, or
        /// `ListToPrimitive` with any broadcast. `data_lhs` is therefore always array-typed.
        ///
        /// `prim_s_lhs` / `prim_s_rhs` are the leaf values of `data_lhs` / `data_rhs`, already
        /// cast to the output primitive dtype that `T` was dispatched on.
        ///
        /// # Panics
        /// If either leaf series has more than one chunk, or if called with a combination that
        /// `finish()` should have rewritten.
        fn _finish_impl<T: PolarsNumericType>(
            &mut self,
            prim_s_lhs: Series,
            prim_s_rhs: Series,
        ) -> ArrayChunked
        where
            T::Native: PlNumArithmetic,
            PrimitiveArray<T::Native>:
                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
            T::Native: Zero + IsFloat,
        {
            let mut arr_lhs = {
                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            let mut arr_rhs = {
                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            // Lets the op adjust the operands' validities before computing; every validity read
            // below sees the adjusted bitmaps. NOTE(review): presumably this masks e.g. zero
            // divisors for integer division - confirm in `list_utils`.
            self.op.0.prepare_numeric_op_side_validities::<T>(
                &mut arr_lhs,
                &mut arr_rhs,
                self.swapped,
            );

            match (&self.op_apply_type, &self.broadcast) {
                // Array<->array where the RHS has a single outer row: every LHS row is combined
                // with the same `stride` RHS leaf values.
                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
                    let mut out_vec: Vec<T::Native> =
                        Vec::with_capacity(self.output_len * self.stride);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
                    let stride = self.stride;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        // SAFETY: LHS has `output_len` rows of `stride` leaves and the broadcast
                        // RHS has `stride` leaves, so all reads are in bounds. Every write index
                        // is below the capacity `output_len * stride` reserved above.
                        unsafe {
                            for outer_idx in 0..self.output_len {
                                for inner_idx in 0..stride {
                                    let l = arr_lhs.value_unchecked(stride * outer_idx + inner_idx);
                                    let r = arr_rhs.value_unchecked(inner_idx);

                                    *out_ptr.add(stride * outer_idx + inner_idx) = $OP(l, r);
                                }
                            }
                        }
                    });

                    // SAFETY: the loop above wrote all `output_len * stride` elements.
                    unsafe { out_vec.set_len(self.output_len * self.stride) };

                    // The single-row RHS leaf validity is tiled to line up with every LHS row.
                    let leaf_validity = combine_validities_and(
                        arr_lhs.validity(),
                        arr_rhs
                            .validity()
                            .map(|x| repeat_bitmap(x, self.output_len))
                            .as_ref(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (_, validities_lhs) = std::mem::take(&mut self.data_lhs);
                    let (_, mut validities_rhs) = std::mem::take(&mut self.data_rhs);

                    // The RHS level validities also describe a single outer row; tile them.
                    for v in validities_rhs.iter_mut() {
                        if let Some(v) = v.as_mut() {
                            *v = repeat_bitmap(v, self.output_len);
                        }
                    }

                    finish_array_to_array_no_broadcast(
                        std::mem::take(&mut self.output_name),
                        &self.output_widths,
                        self.output_len,
                        &validities_lhs,
                        &validities_rhs,
                        Box::new(arr),
                    )
                },
                // Array<->primitive where the array side has a single outer row: primitive
                // value `i` is combined with the same `stride` array leaf values for every `i`.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
                    let mut out_vec: Vec<T::Native> =
                        Vec::with_capacity(self.output_len * self.stride);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
                    let stride = self.stride;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        // SAFETY: the broadcast LHS has `stride` leaves and the RHS has
                        // `output_len` values, so all reads are in bounds. Every write index is
                        // below the reserved capacity `output_len * stride`.
                        unsafe {
                            for outer_idx in 0..self.output_len {
                                let r = arr_rhs.value_unchecked(outer_idx);

                                for inner_idx in 0..stride {
                                    let l = arr_lhs.value_unchecked(inner_idx);

                                    *out_ptr.add(stride * outer_idx + inner_idx) = $OP(l, r);
                                }
                            }
                        }
                    });

                    // SAFETY: the loop above wrote all `output_len * stride` elements.
                    unsafe { out_vec.set_len(self.output_len * self.stride) };

                    // Tile the single-row LHS leaf validity, then AND each primitive validity bit
                    // into the `stride` leaves of its row.
                    let leaf_validity = combine_validities_array_to_primitive_no_broadcast(
                        arr_lhs
                            .validity()
                            .map(|x| repeat_bitmap(x, self.output_len))
                            .as_ref(),
                        arr_rhs.validity(),
                        self.stride,
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (_, mut validities) = std::mem::take(&mut self.data_lhs);

                    // The single-row LHS level validities are tiled to `output_len` rows.
                    for v in validities.iter_mut() {
                        if let Some(v) = v.as_mut() {
                            *v = repeat_bitmap(v, self.output_len);
                        }
                    }

                    finish_with_level_validities(
                        std::mem::take(&mut self.output_name),
                        &self.output_widths,
                        self.output_len,
                        &validities,
                        Box::new(arr),
                    )
                },
                // Array<->primitive with equal lengths: primitive value `i` is combined with the
                // `stride` leaf values of array row `i`.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    let mut out_vec: Vec<T::Native> =
                        Vec::with_capacity(self.output_len * self.stride);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
                    let stride = self.stride;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        // SAFETY: LHS has `output_len * stride` leaves and RHS has `output_len`
                        // values, so all reads are in bounds. `idx` stays below the reserved
                        // capacity `output_len * stride`.
                        unsafe {
                            for outer_idx in 0..self.output_len {
                                let r = arr_rhs.value_unchecked(outer_idx);

                                for inner_idx in 0..stride {
                                    let idx = stride * outer_idx + inner_idx;
                                    let l = arr_lhs.value_unchecked(idx);

                                    *out_ptr.add(idx) = $OP(l, r);
                                }
                            }
                        }
                    });

                    // SAFETY: the loop above wrote all `output_len * stride` elements.
                    unsafe { out_vec.set_len(self.output_len * self.stride) };

                    let leaf_validity = combine_validities_array_to_primitive_no_broadcast(
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        self.stride,
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (_, validities) = std::mem::take(&mut self.data_lhs);

                    finish_with_level_validities(
                        std::mem::take(&mut self.output_name),
                        &self.output_widths,
                        self.output_len,
                        &validities,
                        Box::new(arr),
                    )
                },
                // Array<->primitive where the primitive side is a single scalar.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    assert_eq!(arr_rhs.len(), 1);

                    // SAFETY: `arr_rhs` has length 1, asserted above.
                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
                        // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL.
                        let (_, validities) = std::mem::take(&mut self.data_lhs);
                        return finish_with_level_validities(
                            std::mem::take(&mut self.output_name),
                            &self.output_widths,
                            self.output_len,
                            &validities,
                            Box::new(
                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
                                    false,
                                    arr_lhs.len(),
                                ))),
                            ),
                        );
                    };

                    let arr = self
                        .op
                        .0
                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);

                    let (_, validities) = std::mem::take(&mut self.data_lhs);

                    finish_with_level_validities(
                        std::mem::take(&mut self.output_name),
                        &self.output_widths,
                        self.output_len,
                        &validities,
                        Box::new(arr),
                    )
                },
                // These combinations are rewritten by `finish()` before reaching this point.
                v @ (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    if cfg!(debug_assertions) {
                        panic!("operation was not re-written: {:?}", v)
                    } else {
                        unreachable!()
                    }
                },
            }
        }
653    }
654
655    /// Build the result of an array<->array operation.
656    #[inline(never)]
657    fn finish_array_to_array_no_broadcast(
658        output_name: PlSmallStr,
659        widths: &[usize],
660        outer_len: usize,
661        validities_lhs: &[Option<Bitmap>],
662        validities_rhs: &[Option<Bitmap>],
663        output_leaf_array: Box<dyn Array>,
664    ) -> ArrayChunked {
665        assert_eq!(
666            [widths.len(), validities_lhs.len(), validities_rhs.len()],
667            [widths.len(); 3]
668        );
669
670        let mut builder = FixedSizeListLevelBuilder::new(outer_len, widths);
671
672        let validities_iter = validities_lhs
673            .iter()
674            .zip(validities_rhs)
675            .map(|(l, r)| combine_validities_and(l.as_ref(), r.as_ref()));
676        // `.rev()` - we build this from the inner level.
677        let mut iter = widths.iter().zip(validities_iter).rev();
678
679        let mut out = {
680            let (width, opt_validity) = iter.next().unwrap();
681            builder.build_level(*width, opt_validity, output_leaf_array)
682        };
683
684        for (width, opt_validity) in iter {
685            out = builder.build_level(*width, opt_validity, Box::new(out))
686        }
687
688        ArrayChunked::with_chunk(output_name, out)
689    }
690
691    /// Used when we are operating between array<->primitive, as in that case we only need the
692    /// validities from the array side.
693    #[inline(never)]
694    fn finish_with_level_validities(
695        output_name: PlSmallStr,
696        widths: &[usize],
697        outer_len: usize,
698        validities: &[Option<Bitmap>],
699        output_leaf_array: Box<dyn Array>,
700    ) -> ArrayChunked {
701        assert_eq!(widths.len(), validities.len());
702
703        let mut builder = FixedSizeListLevelBuilder::new(outer_len, widths);
704
705        let validities_iter = validities.iter().cloned();
706        // `.rev()` - we build this from the inner level.
707        let mut iter = widths.iter().zip(validities_iter).rev();
708
709        let mut out = {
710            let (width, opt_validity) = iter.next().unwrap();
711            builder.build_level(*width, opt_validity, output_leaf_array)
712        };
713
714        for (width, opt_validity) in iter {
715            out = builder.build_level(*width, opt_validity, Box::new(out))
716        }
717
718        ArrayChunked::with_chunk(output_name, out)
719    }
720
721    /// ```text
722    /// array      [x, x, x, x, ..] (stride 2)
723    ///             | /   | /
724    ///             |/    |/
725    /// primitive  [x,    x,    ..]
726    /// ```
727    #[inline(never)]
728    fn combine_validities_array_to_primitive_no_broadcast(
729        array_leaf_validity: Option<&Bitmap>,
730        primitive_validity: Option<&Bitmap>,
731        stride: usize,
732    ) -> Option<Bitmap> {
733        match (array_leaf_validity, primitive_validity) {
734            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
735            (Some(v), None) => return Some(v.clone()),
736            // Materialize a full-true validity to re-use the codepath, as we still
737            // need to spread the bits from the RHS to the correct positions.
738            (None, Some(v)) => Some((Bitmap::new_with_value(true, stride * v.len()).make_mut(), v)),
739            (None, None) => None,
740        }
741        .map(|(mut validity_out, primitive_validity)| {
742            assert_eq!(validity_out.len(), stride * primitive_validity.len());
743
744            unsafe {
745                for outer_idx in 0..primitive_validity.len() {
746                    let r = primitive_validity.get_bit_unchecked(outer_idx);
747
748                    for inner_idx in 0..stride {
749                        let idx = stride * outer_idx + inner_idx;
750                        let l = validity_out.get_unchecked(idx);
751
752                        validity_out.set_unchecked(idx, l & r);
753                    }
754                }
755            }
756
757            validity_out.freeze()
758        })
759    }
760
761    /// Returns `n_repeats` concatenated copies of the bitmap.
762    #[inline(never)]
763    fn repeat_bitmap(bitmap: &Bitmap, n_repeats: usize) -> Bitmap {
764        let mut out = BitmapBuilder::with_capacity(bitmap.len() * n_repeats);
765
766        for _ in 0..n_repeats {
767            for bit in bitmap.iter() {
768                unsafe { out.push_unchecked(bit) }
769            }
770        }
771
772        out.freeze()
773    }
774
775    struct FixedSizeListLevelBuilder {
776        heights: <Vec<usize> as IntoIterator>::IntoIter,
777    }
778
779    impl FixedSizeListLevelBuilder {
780        fn new(outer_len: usize, widths: &[usize]) -> Self {
781            let mut current_height = outer_len;
782            // We need to calculate heights here like this rather than dividing the stride because
783            // there can be 0-width arrays.
784            let mut heights = Vec::with_capacity(widths.len());
785
786            heights.push(current_height);
787            heights.extend(widths.iter().take(widths.len() - 1).map(|width| {
788                current_height *= *width;
789                current_height
790            }));
791
792            Self {
793                heights: heights.into_iter(),
794            }
795        }
796    }
797
798    impl FixedSizeListLevelBuilder {
799        fn build_level(
800            &mut self,
801            width: usize,
802            opt_validity: Option<Bitmap>,
803            inner_array: Box<dyn Array>,
804        ) -> FixedSizeListArray {
805            let level_height = self.heights.next_back().unwrap();
806            assert_eq!(inner_array.len(), level_height * width);
807
808            FixedSizeListArray::new(
809                ArrowDataType::FixedSizeList(
810                    Box::new(ArrowField::new(
811                        PlSmallStr::from_static("item"),
812                        inner_array.dtype().clone(),
813                        // is_nullable, we always set true otherwise the Eq kernels would panic
814                        // when they assert == on the arrow `Field`
815                        true,
816                    )),
817                    width,
818                ),
819                level_height,
820                inner_array,
821                opt_validity,
822            )
823        }
824    }
825}