polars_core/series/arithmetic/
list.rs

//! Allow arithmetic operations for ListChunked.
3
4use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8
9impl NumOpsDispatchInner for ListType {
10    fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
11        NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
12    }
13
14    fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
15        NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
16    }
17
18    fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
19        NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
20    }
21
22    fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
23        NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
24    }
25
26    fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
27        NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
28    }
29}
30
/// A numeric arithmetic operation (add / sub / mul / div / rem / floor_div)
/// applicable between columns where at least one side is a `List` dtype.
#[derive(Clone)]
pub struct NumericListOp(NumericOp);
33
impl NumericListOp {
    /// Addition (`+`).
    pub fn add() -> Self {
        Self(NumericOp::Add)
    }

    /// Subtraction (`-`).
    pub fn sub() -> Self {
        Self(NumericOp::Sub)
    }

    /// Multiplication (`*`).
    pub fn mul() -> Self {
        Self(NumericOp::Mul)
    }

    /// True division (`/`).
    pub fn div() -> Self {
        Self(NumericOp::Div)
    }

    /// Remainder (`%`).
    pub fn rem() -> Self {
        Self(NumericOp::Rem)
    }

    /// Floor division.
    pub fn floor_div() -> Self {
        Self(NumericOp::FloorDiv)
    }
}
59
impl NumericListOp {
    /// Apply this arithmetic operation between `lhs` and `rhs`, where at least
    /// one of the two sides is expected to be a `List` series.
    ///
    /// Only functional with the `list_arithmetic` feature; without it the
    /// `feature_gated!` macro raises an error.
    #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
    pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
        feature_gated!("list_arithmetic", {
            use either::Either;

            // `trim_to_normalized_offsets` ensures we don't perform excessive
            // memory allocation / compute on memory regions that have been
            // sliced out.
            let lhs = lhs.list_rechunk_and_trim_to_normalized_offsets();
            let rhs = rhs.list_rechunk_and_trim_to_normalized_offsets();

            // `try_new` validates dtypes / lengths and may directly return the
            // finished (all-NULL) result via `Either::Right` for trivial cases.
            let binary_op_exec = match ListNumericOpHelper::try_new(
                self.clone(),
                lhs.name().clone(),
                lhs.dtype(),
                rhs.dtype(),
                lhs.len(),
                rhs.len(),
                {
                    let (a, b) = lhs.list_offsets_and_validities_recursive();
                    // Offsets were normalized above, so they must start at 0.
                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
                    (a, b, lhs.clone())
                },
                {
                    let (a, b) = rhs.list_offsets_and_validities_recursive();
                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
                    (a, b, rhs.clone())
                },
                lhs.rechunk_validity(),
                rhs.rechunk_validity(),
            )? {
                Either::Left(v) => v,
                Either::Right(ca) => return Ok(ca.into_series()),
            };

            Ok(binary_op_exec.finish()?.into_series())
        })
    }
}
100
101#[cfg(feature = "list_arithmetic")]
102use inner::ListNumericOpHelper;
103
104#[cfg(feature = "list_arithmetic")]
105mod inner {
106    use arrow::bitmap::Bitmap;
107    use arrow::compute::utils::combine_validities_and;
108    use arrow::offset::OffsetsBuffer;
109    use either::Either;
110    use list_utils::with_match_pl_num_arith;
111    use num_traits::Zero;
112    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
113    use polars_utils::float::IsFloat;
114
115    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
116    use super::super::*;
117
    /// Utility to perform a binary operation between the primitive values of
    /// 2 columns, where at least one of the columns is a `ListChunked` type.
    pub(super) struct ListNumericOpHelper {
        /// The arithmetic operation to apply at the primitive level.
        op: NumericListOp,
        /// Name given to the output.
        output_name: PlSmallStr,
        /// Which sides are lists (list<->list, list<->primitive, primitive<->list).
        op_apply_type: BinaryOpApplyType,
        /// Which side, if any, has unit length and broadcasts.
        broadcast: Broadcast,
        /// Full (nested list) dtype of the output.
        output_dtype: DataType,
        /// Leaf (primitive) dtype of the output.
        output_primitive_dtype: DataType,
        /// Row count of the output.
        output_len: usize,
        /// Outer validity of the result, we always materialize this to reduce the
        /// amount of code paths we need.
        outer_validity: Bitmap,
        // The series are stored as they are used for list broadcasting.
        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Materialized broadcasted list leaf array and its length, set by
        /// `finish()` for the broadcasting list<->primitive paths.
        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
        /// Whether lhs/rhs were swapped during dispatch; forwarded to the
        /// arithmetic kernels so they can account for operand order.
        swapped: bool,
    }
137
138    /// This lets us separate some logic into `new()` to reduce the amount of
139    /// monomorphized code.
140    impl ListNumericOpHelper {
        /// Checks that:
        /// * Dtypes are compatible:
        ///   * list<->primitive | primitive<->list
        ///   * list<->list both contain primitives (e.g. List<Int8>)
        /// * Primitive dtypes match
        /// * Lengths are compatible:
        ///   * 1<->n | n<->1
        ///   * n<->n
        /// * Both sides have at least 1 non-NULL outer row.
        ///
        /// Does not check:
        /// * Whether the offsets are aligned for list<->list, this will be checked during execution.
        ///
        /// This returns an `Either` which may contain the final result to simplify
        /// the implementation.
        #[allow(clippy::too_many_arguments)]
        pub(super) fn try_new(
            op: NumericListOp,
            output_name: PlSmallStr,
            dtype_lhs: &DataType,
            dtype_rhs: &DataType,
            len_lhs: usize,
            len_rhs: usize,
            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            validity_lhs: Option<Bitmap>,
            validity_rhs: Option<Bitmap>,
        ) -> PolarsResult<Either<Self, ListChunked>> {
            // Output leaf dtype is the supertype of both sides' leaf dtypes.
            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            // True if every nesting level is `List` down to a supported
            // (arithmetic-capable) leaf dtype.
            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::List(inner) => is_list_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_list_type_at_all_levels = |dtype: &DataType| {
                if !is_list_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not list on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            // Classify the operation shape and pick the output's nested dtype.
            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (l @ DataType::List(a), r @ DataType::List(b)) => {
                    // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user
                    // directly adds 2 series together it bypasses the DSL.
                    // This is currently duplicated code and should be replaced one day with an assert after Series ops get
                    // checked properly.
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input())
                    {
                        polars_bail!(
                            InvalidOperation:
                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                            op.0.name(), l, r,
                        );
                    }
                    (BinaryOpApplyType::ListToList, l)
                },
                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
                },
                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "{} operation not supported for dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            // Replace the nested dtype's leaf with the computed supertype.
            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            // Unit-length columns broadcast against the other side.
            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            let DataType::List(output_inner_dtype) = &output_dtype else {
                unreachable!()
            };

            // # NULL semantics
            // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]]
            //   * Essentially as if the NULL primitive was added to every primitive in the row of the list column.
            // * NULL (List[Int64]) + 1   (Int64)       => NULL
            // * NULL (List[Int64]) + [1] (List[Int64]) => NULL

            // Short-circuit to a full-NULL result when the output is empty, or
            // when a list side has no valid outer rows at all.
            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
            {
                return Ok(Either::Right(ListChunked::full_null_with_dtype(
                    output_name.clone(),
                    output_len,
                    output_inner_dtype.as_ref(),
                )));
            }

            // At this point:
            // * All unit length list columns have a valid outer value.

            // The outer validity is just the validity of any non-broadcasting lists.
            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
                // Both lists with same length, we combine the validity.
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
                    combine_validities_and(l.as_ref(), r.as_ref())
                },
                // Match all other combinations that have non-broadcasting lists.
                (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
                    Broadcast::NoBroadcast | Broadcast::Right,
                    v,
                    _,
                )
                | (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
                    Broadcast::NoBroadcast | Broadcast::Left,
                    _,
                    v,
                ) => v,
                _ => None,
            }
            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));

            Ok(Either::Left(Self {
                op,
                output_name,
                op_apply_type,
                broadcast,
                output_dtype: output_dtype.clone(),
                output_primitive_dtype,
                output_len,
                outer_validity,
                data_lhs,
                data_rhs,
                list_to_prim_lhs: None,
                swapped: false,
            }))
        }
313
        /// Normalize the (apply-type, broadcast) combination onto one of the
        /// physical codepaths and run it, returning the final `ListChunked`.
        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
            // We have physical codepaths for a subset of the possible combinations of broadcasting and
            // column types. The remaining combinations are handled by dispatching to the physical
            // codepaths after operand swapping and/or materialized broadcasting.
            //
            // # Physical impl table
            // Legend
            // * |  N  | // impl "N"
            // * | [N] | // dispatches to impl "N"
            //
            //                  |  L  |  N  |  R  | // Broadcast (L)eft, (N)oBroadcast, (R)ight
            // ListToList       | [1] |  0  |  1  |
            // ListToPrimitive  | [2] |  2  |  3  | // list broadcasting just materializes and dispatches to NoBroadcast
            // PrimitiveToList  | [3] | [2] | [2] |

            // Default to "swapped"; the arms that don't swap operands reset this
            // to `false`. The flag is forwarded to the arithmetic kernels so
            // they can account for operand order.
            self.swapped = true;

            match (&self.op_apply_type, &self.broadcast) {
                // Physical impls — operands already in the expected order.
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToList, Broadcast::Right)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                // [1]: swap sides so the broadcasting list ends up on the right.
                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                // [2]: materialize the broadcasted list, then run as NoBroadcast.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_lhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.broadcast = Broadcast::NoBroadcast;

                    // This does not swap! We are just dispatching to `NoBroadcast`
                    // after materializing the broadcasted list array.
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                // [2]: swap sides to turn primitive<->list into list<->primitive.
                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
                    // We materialize the list columns with `new_from_index`, as otherwise we'd have to
                    // implement logic that broadcasts the offsets and validities across multiple levels
                    // of nesting. But we will re-use the materialized memory to store the result.

                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_rhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::NoBroadcast;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                // [3]: swap sides; the list then broadcasts from the right.
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
            }
        }
393
        /// Casts both leaf arrays to the output primitive dtype, then dispatches
        /// to the monomorphized physical implementation for that dtype.
        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
            let output_dtype = self.output_dtype.clone();
            let output_len = self.output_len;

            // Cast the leaf (primitive) arrays of both sides to the shared
            // output primitive dtype, rechunked to a single chunk each.
            let prim_lhs = self
                .data_lhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();
            let prim_rhs = self
                .data_rhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();

            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
            let prim_dtype = prim_lhs.dtype();
            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);

            // Safety: Leaf dtypes have been checked to be numeric by `try_new()`
            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
                self._finish_impl::<$T>(prim_lhs, prim_rhs)
            })?;

            debug_assert_eq!(out.dtype(), &output_dtype);
            assert_eq!(out.len(), output_len);

            Ok(out)
        }
425
426        /// Internal use only - contains physical impls.
427        fn _finish_impl<T: PolarsNumericType>(
428            &mut self,
429            prim_s_lhs: Series,
430            prim_s_rhs: Series,
431        ) -> PolarsResult<ListChunked>
432        where
433            T::Native: PlNumArithmetic,
434            PrimitiveArray<T::Native>:
435                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
436            T::Native: Zero + IsFloat,
437        {
            /// Raise a shape error if `mismatch_pos` stopped advancing before the
            /// end of the offsets, i.e. a row with differing list lengths was hit.
            /// Kept out-of-line to reduce monomorphized code size.
            #[inline(never)]
            fn check_mismatch_pos(
                mismatch_pos: usize,
                offsets_lhs: &OffsetsBuffer<i64>,
                offsets_rhs: &OffsetsBuffer<i64>,
            ) -> PolarsResult<()> {
                if mismatch_pos < offsets_lhs.len_proxy() {
                    // RHS could be broadcasted
                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
                        0
                    } else {
                        mismatch_pos
                    });
                    polars_bail!(
                        ShapeMismatch:
                        "list lengths differed at index {}: {} != {}",
                        mismatch_pos,
                        offsets_lhs.length_at(mismatch_pos), len_r
                    )
                }
                Ok(())
            }
460
461            let mut arr_lhs = {
462                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
463                assert_eq!(ca.chunks().len(), 1);
464                ca.downcast_get(0).unwrap().clone()
465            };
466
467            let mut arr_rhs = {
468                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
469                assert_eq!(ca.chunks().len(), 1);
470                ca.downcast_get(0).unwrap().clone()
471            };
472
473            match (&self.op_apply_type, &self.broadcast) {
474                // We skip for this because it dispatches to `ArithmeticKernel`, which handles the
475                // validities for us.
476                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
477                _ if self.list_to_prim_lhs.is_none() => {
478                    self.op.0.prepare_numeric_op_side_validities::<T>(
479                        &mut arr_lhs,
480                        &mut arr_rhs,
481                        self.swapped,
482                    )
483                },
484                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
485                    // `self.list_to_prim_lhs` is `Some(_)`, this is handled later.
486                },
487                _ => unreachable!(),
488            }
489
490            //
491            // General notes
492            // * Lists can be:
493            //   * Sliced, in which case the primitive/leaf array needs to be indexed starting from an
494            //     offset instead of 0.
495            //   * Masked, in which case the masked rows are permitted to have non-matching widths.
496            //
497
498            let out = match (&self.op_apply_type, &self.broadcast) {
499                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
500                    let offsets_lhs = &self.data_lhs.0[0];
501                    let offsets_rhs = &self.data_rhs.0[0];
502
503                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());
504
505                    // Output primitive (and optional validity) are aligned to the LHS input.
506                    let n_values = arr_lhs.len();
507                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
508                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
509
510                    // Counter that stops being incremented at the first row position with mismatching
511                    // list lengths.
512                    let mut mismatch_pos = 0;
513
514                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
515                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
516                            .offset_and_length_iter()
517                            .zip(offsets_rhs.offset_and_length_iter())
518                            .enumerate()
519                        {
520                            if
521                                (mismatch_pos == i)
522                                & (
523                                    (lhs_len == rhs_len)
524                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
525                                )
526                            {
527                                mismatch_pos += 1;
528                            }
529
530                            // Both sides are lists, we restrict the index to the min length to avoid
531                            // OOB memory access.
532                            let len: usize = lhs_len.min(rhs_len);
533
534                            for i in 0..len {
535                                let l_idx = i + lhs_start;
536                                let r_idx = i + rhs_start;
537
538                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
539                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
540                                let v = $OP(l, r);
541
542                                unsafe { out_ptr.add(l_idx).write(v) };
543                            }
544                        }
545                    });
546
547                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
548
549                    unsafe { out_vec.set_len(n_values) };
550
                    /// Combine the leaf validities of two equal-row-count list
                    /// columns (output aligned to the LHS leaf array).
                    /// Reduce monomorphization
                    #[inline(never)]
                    fn combine_validities_list_to_list_no_broadcast(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        offsets_rhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                    ) -> Option<Bitmap> {
                        // Build a mutable output validity aligned to the LHS; an
                        // LHS-only validity can be returned unchanged.
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
                                .offset_and_length_iter()
                                .zip(offsets_rhs.offset_and_length_iter())
                            {
                                // Min length mirrors the value loop above; rows with
                                // mismatching widths were already handled there.
                                let len: usize = lhs_len.min(rhs_len);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    // Size and alignment of validity vec are based on LHS.
                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }
591
592                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
593                        offsets_lhs,
594                        offsets_rhs,
595                        arr_lhs.validity(),
596                        arr_rhs.validity(),
597                        arr_lhs.len(),
598                    );
599
600                    let arr =
601                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
602
603                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
604                    assert_eq!(offsets.len(), 1);
605
606                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
607                },
608                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
609                    let offsets_lhs = &self.data_lhs.0[0];
610                    let offsets_rhs = &self.data_rhs.0[0];
611
612                    // Output primitive (and optional validity) are aligned to the LHS input.
613                    let n_values = arr_lhs.len();
614                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
615                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
616
617                    assert_eq!(offsets_rhs.len_proxy(), 1);
618                    let rhs_start = *offsets_rhs.first() as usize;
619                    let width = offsets_rhs.range() as usize;
620
621                    let mut mismatch_pos = 0;
622
623                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
624                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
625                            if ((lhs_len == width) & (mismatch_pos == i))
626                                | unsafe { !self.outer_validity.get_bit_unchecked(i) }
627                            {
628                                mismatch_pos += 1;
629                            }
630
631                            let len: usize = lhs_len.min(width);
632
633                            for i in 0..len {
634                                let l_idx = i + lhs_start;
635                                let r_idx = i + rhs_start;
636
637                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
638                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
639                                let v = $OP(l, r);
640
641                                unsafe {
642                                    out_ptr.add(l_idx).write(v);
643                                }
644                            }
645                        }
646                    });
647
648                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
649
650                    unsafe { out_vec.set_len(n_values) };
651
                    /// Combine leaf validities when the RHS list is a single
                    /// broadcasting row (output aligned to the LHS leaf array).
                    /// Kept out-of-line to reduce monomorphized code size.
                    #[inline(never)]
                    fn combine_validities_list_to_list_broadcast_right(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                        width: usize,
                        rhs_start: usize,
                    ) -> Option<Bitmap> {
                        // Build a mutable output validity aligned to the LHS; an
                        // LHS-only validity can be returned unchanged.
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
                                let len: usize = lhs_len.min(width);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    // RHS indices always come from the single broadcast row.
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    // Size and alignment of validity vec are based on LHS.
                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }
689
690                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
691                        offsets_lhs,
692                        arr_lhs.validity(),
693                        arr_rhs.validity(),
694                        arr_lhs.len(),
695                        width,
696                        rhs_start,
697                    );
698
699                    let arr =
700                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
701
702                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
703                    assert_eq!(offsets.len(), 1);
704
705                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
706                },
707                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
708                    if self.list_to_prim_lhs.is_none() =>
709                {
710                    let offsets_lhs = self.data_lhs.0.as_slice();
711
712                    // Notes
713                    // * Primitive indexing starts from 0
714                    // * Output is aligned to LHS array
715
716                    let n_values = arr_lhs.len();
717                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
718                    let out_ptr = out_vec.as_mut_ptr();
719
720                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
721                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
722                        {
723                            let r = unsafe { arr_rhs.value_unchecked(i) };
724                            for l_idx in l_range {
725                                unsafe {
726                                    let l = arr_lhs.value_unchecked(l_idx);
727                                    let v = $OP(l, r);
728                                    out_ptr.add(l_idx).write(v);
729                                }
730                            }
731                        }
732                    });
733
734                    unsafe { out_vec.set_len(n_values) }
735
736                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
737                        offsets_lhs,
738                        arr_lhs.validity(),
739                        arr_rhs.validity(),
740                        arr_lhs.len(),
741                    );
742
743                    let arr =
744                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
745
746                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
747                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
748                },
749                // If we are dispatched here, it means that the LHS array is a unique allocation created
750                // after a unit-length list column was broadcasted, so this codepath mutably stores the
751                // results back into the LHS array to save memory.
752                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
753                    let offsets_lhs = self.data_lhs.0.as_slice();
754
755                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
756                    let arr = arr
757                        .as_any_mut()
758                        .downcast_mut::<PrimitiveArray<T::Native>>()
759                        .unwrap();
760                    let mut arr_lhs = std::mem::take(arr);
761
762                    self.op.0.prepare_numeric_op_side_validities::<T>(
763                        &mut arr_lhs,
764                        &mut arr_rhs,
765                        self.swapped,
766                    );
767
768                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
769                    assert_eq!(arr_lhs_mut_slice.len(), n_values);
770
771                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
772                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
773                        {
774                            let r = unsafe { arr_rhs.value_unchecked(i) };
775                            for l_idx in l_range {
776                                unsafe {
777                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
778                                    *l = $OP(*l, r);
779                                }
780                            }
781                        }
782                    });
783
784                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
785                        offsets_lhs,
786                        arr_lhs.validity(),
787                        arr_rhs.validity(),
788                        arr_lhs.len(),
789                    );
790
791                    let arr = arr_lhs.with_validity(leaf_validity);
792
793                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
794                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
795                },
796                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
797                    assert_eq!(arr_rhs.len(), 1);
798
799                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
800                        // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL.
801                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
802                        return Ok(self.finish_offsets_and_validities(
803                            Box::new(
804                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
805                                    false,
806                                    arr_lhs.len(),
807                                ))),
808                            ),
809                            offsets,
810                            validities,
811                        ));
812                    };
813
814                    let arr = self
815                        .op
816                        .0
817                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
818                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
819
820                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
821                },
822                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
823                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
824                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
825                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
826                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
827                    if cfg!(debug_assertions) {
828                        panic!("operation was not re-written: {:?}", v)
829                    } else {
830                        unreachable!()
831                    }
832                },
833            };
834
835            Ok(out)
836        }
837
838        /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every
839        /// level.
840        fn finish_offsets_and_validities(
841            &mut self,
842            leaf_array: Box<dyn Array>,
843            offsets: Vec<OffsetsBuffer<i64>>,
844            validities: Vec<Option<Bitmap>>,
845        ) -> ListChunked {
846            assert!(!offsets.is_empty());
847            assert_eq!(offsets.len(), validities.len());
848            let mut results = leaf_array;
849
850            let mut iter = offsets.into_iter().zip(validities).rev();
851
852            while iter.len() > 1 {
853                let (offsets, validity) = iter.next().unwrap();
854                let dtype = LargeListArray::default_datatype(results.dtype().clone());
855                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
856            }
857
858            // The combined outer validity is pre-computed during `try_new()`
859            let (offsets, _) = iter.next().unwrap();
860            let validity = std::mem::take(&mut self.outer_validity);
861            let dtype = LargeListArray::default_datatype(results.dtype().clone());
862            let results = LargeListArray::new(dtype, offsets, results, Some(validity));
863
864            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
865        }
866
        /// Materialize a unit-length (broadcasted) list side into owned, full-length data
        /// by repeating it `output_len` times.
        ///
        /// On return, the offsets/validities in `side_data` are replaced with those of the
        /// repeated list (the `Series` slot is reset to `Series::default()` as it is no
        /// longer needed), and the repeated leaf array plus its leaf-value count are
        /// returned as `(leaf_array, n_values)`.
        fn materialize_broadcasted_list(
            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            output_len: usize,
            output_primitive_dtype: &DataType,
        ) -> (Box<dyn Array>, usize) {
            let s = &side_data.2;
            // Broadcasting only makes sense from a unit-length side.
            assert_eq!(s.len(), 1);

            // Leaf-value count we expect after repeating the unit-length list
            // `output_len` times; cross-checked against the materialized result below.
            let expected_n_values = {
                let offsets = s.list_offsets_and_validities_recursive().0;
                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
            };

            let ca = s.list().unwrap();
            // Remember to cast the leaf primitives to the supertype.
            let ca = ca
                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
                .unwrap();
            assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data.
            let ca = ca.new_from_index(0, output_len).rechunk();

            let s = ca.into_series();

            // Swap in the offsets/validities of the repeated list.
            *side_data = {
                let (a, b) = s.list_offsets_and_validities_recursive();
                // `Series::default()`: This field in the tuple is no longer used.
                (a, b, Series::default())
            };

            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
            assert_eq!(n_values, expected_n_values);

            let mut s = s.get_leaf_array();
            // NOTE(review): `s` was just built via `new_from_index(..).rechunk()` above,
            // which should make it a freshly-owned single chunk — confirm this upholds
            // the safety contract of `chunks_mut`.
            let v = unsafe { s.chunks_mut() };

            // `rechunk()` above guarantees a single chunk.
            assert_eq!(v.len(), 1);
            (v.swap_remove(0), n_values)
        }
905    }
906
907    /// Used in 2 places, so it's outside here.
908    #[inline(never)]
909    fn combine_validities_list_to_primitive_no_broadcast(
910        offsets_lhs: &[OffsetsBuffer<i64>],
911        validity_lhs: Option<&Bitmap>,
912        validity_rhs: Option<&Bitmap>,
913        len_lhs: usize,
914    ) -> Option<Bitmap> {
915        match (validity_lhs, validity_rhs) {
916            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
917            (Some(v), None) => return Some(v.clone()),
918            // Materialize a full-true validity to re-use the codepath, as we still
919            // need to spread the bits from the RHS to the correct positions.
920            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
921            (None, None) => None,
922        }
923        .map(|(mut validity_out, validity_rhs)| {
924            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
925                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
926                for l_idx in l_range {
927                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
928                    let is_valid = l_valid & r_valid;
929
930                    // Size and alignment of validity vec are based on LHS.
931                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
932                }
933            }
934
935            validity_out.freeze()
936        })
937    }
938}